From dd829d771af56155d5f3b101e061c2d946691012 Mon Sep 17 00:00:00 2001 From: Divided by Zer0 Date: Tue, 16 Jul 2024 16:58:17 +0200 Subject: [PATCH] feat: Updates Facerestore node to latest version. (#284) * facerestore_cf commit * Add missing nodes * feat: ensure helper models downloaded by model manager * fix: moved downloads to gfpgan manager * fix: Adds clean implementation of gfpgan loading * chore: license * chore: only gitignore `hordelib/models/` (and not other models/ dirs) * fix: add missing models/ files for facerestore * tests: more stringent histogram fail for post processor image checks * chore: clarify sourcing and licensing of facerestore_cf * style: fix style: fix types for pipline map --------- Co-authored-by: tazlin --- .gitignore | 2 +- hordelib/comfy_horde.py | 8 + hordelib/horde.py | 5 +- hordelib/nodes/facerestore/__init__.py | 204 ---- .../yolov5face/utils/extract_ckpt.py | 5 - .../facerestore/facelib/utils/__init__.py | 7 - hordelib/nodes/facerestore_cf/LICENSE | 674 +++++++++++ hordelib/nodes/facerestore_cf/README.md | 1 + hordelib/nodes/facerestore_cf/__init__.py | 333 ++++++ hordelib/nodes/facerestore_cf/basicsr/VERSION | 1 + .../nodes/facerestore_cf/basicsr/__init__.py | 11 + .../facerestore_cf/basicsr/archs/__init__.py | 27 + .../basicsr/archs/arcface_arch.py | 252 ++++ .../facerestore_cf/basicsr/archs/arch_util.py | 342 ++++++ .../basicsr/archs/codeformer_arch.py | 291 +++++ .../basicsr/archs/rrdbnet_arch.py | 120 ++ .../facerestore_cf/basicsr/archs/vgg_arch.py | 264 +++++ .../basicsr/archs/vqgan_arch.py | 456 ++++++++ .../facerestore_cf/basicsr/data/__init__.py | 103 ++ .../basicsr/data/data_sampler.py | 49 + .../facerestore_cf/basicsr/data/data_util.py | 313 +++++ .../basicsr/data/prefetch_dataloader.py | 126 ++ .../facerestore_cf/basicsr/data/transforms.py | 170 +++ .../facerestore_cf/basicsr/losses/__init__.py | 43 + .../basicsr/losses/loss_util.py | 96 ++ .../facerestore_cf/basicsr/losses/losses.py | 470 ++++++++ .../basicsr/metrics/__init__.py | 20 + .../basicsr/metrics/metric_util.py | 45 + .../basicsr/metrics/psnr_ssim.py | 128 ++ .../facerestore_cf/basicsr/models/__init__.py | 30 + .../basicsr/ops}/__init__.py | 0 .../basicsr/ops/dcn/__init__.py | 17 + .../basicsr/ops/dcn/deform_conv.py | 503 ++++++++ .../basicsr/ops/dcn/src/deform_conv_cuda.cpp | 685 +++++++++++ .../ops/dcn/src/deform_conv_cuda_kernel.cu | 867 ++++++++++++++ .../basicsr/ops/dcn/src/deform_conv_ext.cpp | 164 +++ .../basicsr/ops/fused_act/__init__.py | 3 + .../basicsr/ops/fused_act/fused_act.py | 98 ++ .../ops/fused_act/src/fused_bias_act.cpp | 26 + .../fused_act/src/fused_bias_act_kernel.cu | 100 ++ .../basicsr/ops/upfirdn2d/__init__.py | 3 + .../basicsr/ops/upfirdn2d/src/upfirdn2d.cpp | 24 + .../ops/upfirdn2d/src/upfirdn2d_kernel.cu | 370 ++++++ .../basicsr/ops/upfirdn2d/upfirdn2d.py | 188 +++ .../nodes/facerestore_cf/basicsr/setup.py | 171 +++ .../nodes/facerestore_cf/basicsr/train.py | 250 ++++ .../facerestore_cf/basicsr/utils/__init__.py | 29 + .../facerestore_cf/basicsr/utils/dist_util.py | 83 ++ .../basicsr/utils/download_util.py | 83 ++ .../basicsr/utils/file_client.py | 172 +++ .../facerestore_cf/basicsr/utils/img_util.py | 171 +++ .../facerestore_cf/basicsr/utils/lmdb_util.py | 200 ++++ .../facerestore_cf/basicsr/utils/logger.py | 174 +++ .../basicsr/utils/matlab_functions.py | 372 ++++++ .../facerestore_cf/basicsr/utils/misc.py | 136 +++ .../facerestore_cf/basicsr/utils/options.py | 109 ++ .../basicsr/utils/realesrgan_utils.py | 319 +++++ 
.../facerestore_cf/basicsr/utils/registry.py | 83 ++ .../nodes/facerestore_cf/basicsr/version.py | 5 + .../facelib}/__init__.py | 0 .../facelib/detection/__init__.py | 232 ++-- .../facelib/detection/align_trans.py | 452 +++---- .../facelib/detection/matlab_cp2tform.py | 633 +++++----- .../detection/retinaface/retinaface.py | 809 +++++++------ .../detection/retinaface/retinaface_net.py | 396 +++---- .../detection/retinaface/retinaface_utils.py | 841 +++++++------ .../facelib/detection/yolov5face}/__init__.py | 0 .../detection/yolov5face/face_detector.py | 323 +++-- .../detection/yolov5face/models}/__init__.py | 0 .../detection/yolov5face/models/common.py | 33 +- .../yolov5face/models/experimental.py | 15 +- .../detection/yolov5face/models/yolo.py | 123 +- .../detection/yolov5face/models/yolov5l.yaml | 0 .../detection/yolov5face/models/yolov5n.yaml | 0 .../detection/yolov5face/utils/__init__.py | 0 .../detection/yolov5face/utils/autoanchor.py | 24 +- .../detection/yolov5face/utils/datasets.py | 70 +- .../yolov5face/utils/extract_ckpt.py | 7 + .../detection/yolov5face/utils/general.py | 542 ++++----- .../detection/yolov5face/utils/torch_utils.py | 80 +- .../facelib/parsing/__init__.py | 59 +- .../facelib/parsing/bisenet.py | 280 ++--- .../facelib/parsing/parsenet.py | 393 ++++--- .../facelib/parsing/resnet.py | 138 +-- .../facerestore_cf/facelib/utils/__init__.py | 13 + .../facelib/utils/face_restoration_helper.py | 1037 ++++++++--------- .../facelib/utils/face_utils.py | 558 +++++---- .../facelib/utils/misc.py | 275 +++-- .../nodes/facerestore_cf/r_chainner/README.md | 3 + .../facerestore_cf/r_chainner/__init__.py | 0 .../r_chainner/gfpganv1_clean_arch.py | 370 ++++++ .../r_chainner/model_loading.py | 29 + .../r_chainner/stylegan2_clean_arch.py | 453 +++++++ .../nodes/facerestore_cf/r_chainner/types.py | 19 + hordelib/nodes/node_model_loader.py | 2 +- .../pipeline_image_facefix.json | 147 +-- .../pipelines/pipeline_image_facefix.json | 31 +- mypy.ini | 6 +- pyproject.toml | 2 +- requirements.txt | 1 + tests/test_horde_pp.py | 4 +- 101 files changed, 14322 insertions(+), 4079 deletions(-) delete mode 100644 hordelib/nodes/facerestore/__init__.py delete mode 100644 hordelib/nodes/facerestore/facelib/detection/yolov5face/utils/extract_ckpt.py delete mode 100644 hordelib/nodes/facerestore/facelib/utils/__init__.py create mode 100644 hordelib/nodes/facerestore_cf/LICENSE create mode 100644 hordelib/nodes/facerestore_cf/README.md create mode 100644 hordelib/nodes/facerestore_cf/__init__.py create mode 100644 hordelib/nodes/facerestore_cf/basicsr/VERSION create mode 100644 hordelib/nodes/facerestore_cf/basicsr/__init__.py create mode 100644 hordelib/nodes/facerestore_cf/basicsr/archs/__init__.py create mode 100644 hordelib/nodes/facerestore_cf/basicsr/archs/arcface_arch.py create mode 100644 hordelib/nodes/facerestore_cf/basicsr/archs/arch_util.py create mode 100644 hordelib/nodes/facerestore_cf/basicsr/archs/codeformer_arch.py create mode 100644 hordelib/nodes/facerestore_cf/basicsr/archs/rrdbnet_arch.py create mode 100644 hordelib/nodes/facerestore_cf/basicsr/archs/vgg_arch.py create mode 100644 hordelib/nodes/facerestore_cf/basicsr/archs/vqgan_arch.py create mode 100644 hordelib/nodes/facerestore_cf/basicsr/data/__init__.py create mode 100644 hordelib/nodes/facerestore_cf/basicsr/data/data_sampler.py create mode 100644 hordelib/nodes/facerestore_cf/basicsr/data/data_util.py create mode 100644 hordelib/nodes/facerestore_cf/basicsr/data/prefetch_dataloader.py create mode 100644 
hordelib/nodes/facerestore_cf/basicsr/data/transforms.py create mode 100644 hordelib/nodes/facerestore_cf/basicsr/losses/__init__.py create mode 100644 hordelib/nodes/facerestore_cf/basicsr/losses/loss_util.py create mode 100644 hordelib/nodes/facerestore_cf/basicsr/losses/losses.py create mode 100644 hordelib/nodes/facerestore_cf/basicsr/metrics/__init__.py create mode 100644 hordelib/nodes/facerestore_cf/basicsr/metrics/metric_util.py create mode 100644 hordelib/nodes/facerestore_cf/basicsr/metrics/psnr_ssim.py create mode 100644 hordelib/nodes/facerestore_cf/basicsr/models/__init__.py rename hordelib/nodes/{facerestore/facelib => facerestore_cf/basicsr/ops}/__init__.py (100%) create mode 100644 hordelib/nodes/facerestore_cf/basicsr/ops/dcn/__init__.py create mode 100644 hordelib/nodes/facerestore_cf/basicsr/ops/dcn/deform_conv.py create mode 100644 hordelib/nodes/facerestore_cf/basicsr/ops/dcn/src/deform_conv_cuda.cpp create mode 100644 hordelib/nodes/facerestore_cf/basicsr/ops/dcn/src/deform_conv_cuda_kernel.cu create mode 100644 hordelib/nodes/facerestore_cf/basicsr/ops/dcn/src/deform_conv_ext.cpp create mode 100644 hordelib/nodes/facerestore_cf/basicsr/ops/fused_act/__init__.py create mode 100644 hordelib/nodes/facerestore_cf/basicsr/ops/fused_act/fused_act.py create mode 100644 hordelib/nodes/facerestore_cf/basicsr/ops/fused_act/src/fused_bias_act.cpp create mode 100644 hordelib/nodes/facerestore_cf/basicsr/ops/fused_act/src/fused_bias_act_kernel.cu create mode 100644 hordelib/nodes/facerestore_cf/basicsr/ops/upfirdn2d/__init__.py create mode 100644 hordelib/nodes/facerestore_cf/basicsr/ops/upfirdn2d/src/upfirdn2d.cpp create mode 100644 hordelib/nodes/facerestore_cf/basicsr/ops/upfirdn2d/src/upfirdn2d_kernel.cu create mode 100644 hordelib/nodes/facerestore_cf/basicsr/ops/upfirdn2d/upfirdn2d.py create mode 100644 hordelib/nodes/facerestore_cf/basicsr/setup.py create mode 100644 hordelib/nodes/facerestore_cf/basicsr/train.py create mode 100644 hordelib/nodes/facerestore_cf/basicsr/utils/__init__.py create mode 100644 hordelib/nodes/facerestore_cf/basicsr/utils/dist_util.py create mode 100644 hordelib/nodes/facerestore_cf/basicsr/utils/download_util.py create mode 100644 hordelib/nodes/facerestore_cf/basicsr/utils/file_client.py create mode 100644 hordelib/nodes/facerestore_cf/basicsr/utils/img_util.py create mode 100644 hordelib/nodes/facerestore_cf/basicsr/utils/lmdb_util.py create mode 100644 hordelib/nodes/facerestore_cf/basicsr/utils/logger.py create mode 100644 hordelib/nodes/facerestore_cf/basicsr/utils/matlab_functions.py create mode 100644 hordelib/nodes/facerestore_cf/basicsr/utils/misc.py create mode 100644 hordelib/nodes/facerestore_cf/basicsr/utils/options.py create mode 100644 hordelib/nodes/facerestore_cf/basicsr/utils/realesrgan_utils.py create mode 100644 hordelib/nodes/facerestore_cf/basicsr/utils/registry.py create mode 100644 hordelib/nodes/facerestore_cf/basicsr/version.py rename hordelib/nodes/{facerestore/facelib/detection/yolov5face => facerestore_cf/facelib}/__init__.py (100%) rename hordelib/nodes/{facerestore => facerestore_cf}/facelib/detection/__init__.py (83%) rename hordelib/nodes/{facerestore => facerestore_cf}/facelib/detection/align_trans.py (75%) rename hordelib/nodes/{facerestore => facerestore_cf}/facelib/detection/matlab_cp2tform.py (88%) rename hordelib/nodes/{facerestore => facerestore_cf}/facelib/detection/retinaface/retinaface.py (83%) rename hordelib/nodes/{facerestore => facerestore_cf}/facelib/detection/retinaface/retinaface_net.py (90%) 
rename hordelib/nodes/{facerestore => facerestore_cf}/facelib/detection/retinaface/retinaface_utils.py (92%) rename hordelib/nodes/{facerestore/facelib/detection/yolov5face/models => facerestore_cf/facelib/detection/yolov5face}/__init__.py (100%) rename hordelib/nodes/{facerestore => facerestore_cf}/facelib/detection/yolov5face/face_detector.py (71%) rename hordelib/nodes/{facerestore/facelib/detection/yolov5face/utils => facerestore_cf/facelib/detection/yolov5face/models}/__init__.py (100%) rename hordelib/nodes/{facerestore => facerestore_cf}/facelib/detection/yolov5face/models/common.py (92%) rename hordelib/nodes/{facerestore => facerestore_cf}/facelib/detection/yolov5face/models/experimental.py (78%) rename hordelib/nodes/{facerestore => facerestore_cf}/facelib/detection/yolov5face/models/yolo.py (64%) rename hordelib/nodes/{facerestore => facerestore_cf}/facelib/detection/yolov5face/models/yolov5l.yaml (100%) rename hordelib/nodes/{facerestore => facerestore_cf}/facelib/detection/yolov5face/models/yolov5n.yaml (100%) create mode 100644 hordelib/nodes/facerestore_cf/facelib/detection/yolov5face/utils/__init__.py rename hordelib/nodes/{facerestore => facerestore_cf}/facelib/detection/yolov5face/utils/autoanchor.py (97%) rename hordelib/nodes/{facerestore => facerestore_cf}/facelib/detection/yolov5face/utils/datasets.py (97%) create mode 100644 hordelib/nodes/facerestore_cf/facelib/detection/yolov5face/utils/extract_ckpt.py rename hordelib/nodes/{facerestore => facerestore_cf}/facelib/detection/yolov5face/utils/general.py (97%) rename hordelib/nodes/{facerestore => facerestore_cf}/facelib/detection/yolov5face/utils/torch_utils.py (97%) rename hordelib/nodes/{facerestore => facerestore_cf}/facelib/parsing/__init__.py (81%) rename hordelib/nodes/{facerestore => facerestore_cf}/facelib/parsing/bisenet.py (85%) rename hordelib/nodes/{facerestore => facerestore_cf}/facelib/parsing/parsenet.py (70%) rename hordelib/nodes/{facerestore => facerestore_cf}/facelib/parsing/resnet.py (97%) create mode 100644 hordelib/nodes/facerestore_cf/facelib/utils/__init__.py rename hordelib/nodes/{facerestore => facerestore_cf}/facelib/utils/face_restoration_helper.py (78%) rename hordelib/nodes/{facerestore => facerestore_cf}/facelib/utils/face_utils.py (93%) rename hordelib/nodes/{facerestore => facerestore_cf}/facelib/utils/misc.py (70%) create mode 100644 hordelib/nodes/facerestore_cf/r_chainner/README.md create mode 100644 hordelib/nodes/facerestore_cf/r_chainner/__init__.py create mode 100644 hordelib/nodes/facerestore_cf/r_chainner/gfpganv1_clean_arch.py create mode 100644 hordelib/nodes/facerestore_cf/r_chainner/model_loading.py create mode 100644 hordelib/nodes/facerestore_cf/r_chainner/stylegan2_clean_arch.py create mode 100644 hordelib/nodes/facerestore_cf/r_chainner/types.py diff --git a/.gitignore b/.gitignore index 9ca20c67..46594cae 100644 --- a/.gitignore +++ b/.gitignore @@ -20,7 +20,7 @@ parts/ sdist/ var/ wheels/ -models/ +hordelib/models/ pip-wheel-metadata/ share/python-wheels/ *.egg-info/ diff --git a/hordelib/comfy_horde.py b/hordelib/comfy_horde.py index 6dbc9f40..996132b8 100644 --- a/hordelib/comfy_horde.py +++ b/hordelib/comfy_horde.py @@ -410,6 +410,14 @@ def _set_comfyui_paths(self) -> None: _comfy_supported_pt_extensions, ) + _comfy_folder_names_and_paths["facerestore_models"] = ( + [ + str(UserSettings.get_model_directory() / "gfpgan"), + str(UserSettings.get_model_directory() / "codeformer"), + ], + _comfy_supported_pt_extensions, + ) + 
_comfy_folder_names_and_paths["controlnet"] = ( [ _comfy_folder_names_and_paths["controlnet"][0][0], diff --git a/hordelib/horde.py b/hordelib/horde.py index ddad5ae4..404dea7e 100644 --- a/hordelib/horde.py +++ b/hordelib/horde.py @@ -214,7 +214,7 @@ class HordeLib: } # pipeline parameter <- hordelib payload parameter mapping - PAYLOAD_TO_PIPELINE_PARAMETER_MAPPING = { # FIXME + PAYLOAD_TO_PIPELINE_PARAMETER_MAPPING: dict[str, str | Callable] = { # FIXME "sampler.sampler_name": "sampler_name", "sampler.cfg": "cfg_scale", "sampler.denoise": "denoising_strength", @@ -820,6 +820,9 @@ def _final_pipeline_adjustments(self, payload, pipeline_data) -> tuple[dict, lis # values for steps on things like stable cascade if isinstance(key, FunctionType): pipeline_params[newkey] = key(payload) + elif not isinstance(key, str): + logger.error(f"Invalid key {key}") + raise RuntimeError(f"Invalid key {key}") elif "*" in key: key, multiplier = key.split("*", 1) elif key in payload: diff --git a/hordelib/nodes/facerestore/__init__.py b/hordelib/nodes/facerestore/__init__.py deleted file mode 100644 index b2fdac14..00000000 --- a/hordelib/nodes/facerestore/__init__.py +++ /dev/null @@ -1,204 +0,0 @@ -import os -import model_management -import torch -import comfy.utils -import numpy as np -import cv2 -import math -from hordelib.nodes.facerestore.facelib.utils.face_restoration_helper import FaceRestoreHelper -from hordelib.nodes.facerestore.facelib.detection.retinaface import retinaface -from torchvision.transforms.functional import normalize -import threading -from loguru import logger - - -def img2tensor(imgs, bgr2rgb=True, float32=True): - """Numpy array to tensor. - - Args: - imgs (list[ndarray] | ndarray): Input images. - bgr2rgb (bool): Whether to change bgr to rgb. - float32 (bool): Whether to change to float32. - - Returns: - list[tensor] | tensor: Tensor images. If returned results only have - one element, just return tensor. - """ - - def _totensor(img, bgr2rgb, float32): - if img.shape[2] == 3 and bgr2rgb: - if img.dtype == "float64": - img = img.astype("float32") - img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) - img = torch.from_numpy(img.transpose(2, 0, 1)) - if float32: - img = img.float() - return img - - if isinstance(imgs, list): - return [_totensor(img, bgr2rgb, float32) for img in imgs] - else: - return _totensor(imgs, bgr2rgb, float32) - - -def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)): - """Convert torch Tensors into image numpy arrays. - - After clamping to [min, max], values will be normalized to [0, 1]. - - Args: - tensor (Tensor or list[Tensor]): Accept shapes: - 1) 4D mini-batch Tensor of shape (B x 3/1 x H x W); - 2) 3D Tensor of shape (3/1 x H x W); - 3) 2D Tensor of shape (H x W). - Tensor channel should be in RGB order. - rgb2bgr (bool): Whether to change rgb to bgr. - out_type (numpy type): output types. If ``np.uint8``, transform outputs - to uint8 type with range [0, 255]; otherwise, float type with - range [0, 1]. Default: ``np.uint8``. - min_max (tuple[int]): min and max values for clamp. - - Returns: - (Tensor or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of - shape (H x W). The channel order is BGR. 
- """ - if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))): - raise TypeError(f"tensor or list of tensors expected, got {type(tensor)}") - - if torch.is_tensor(tensor): - tensor = [tensor] - result = [] - for _tensor in tensor: - _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max) - _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0]) - - n_dim = _tensor.dim() - if n_dim == 4: - img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy() - img_np = img_np.transpose(1, 2, 0) - if rgb2bgr: - img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) - elif n_dim == 3: - img_np = _tensor.numpy() - img_np = img_np.transpose(1, 2, 0) - if img_np.shape[2] == 1: # gray image - img_np = np.squeeze(img_np, axis=2) - else: - if rgb2bgr: - img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) - elif n_dim == 2: - img_np = _tensor.numpy() - else: - raise TypeError("Only support 4D, 3D or 2D tensor. " f"But received with dimension: {n_dim}") - if out_type == np.uint8: - # Unlike MATLAB, numpy.unit8() WILL NOT round by default. - img_np = (img_np * 255.0).round() - img_np = img_np.astype(out_type) - result.append(img_np) - if len(result) == 1: - result = result[0] - return result - - -class FaceRestoreWithModel: - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "upscale_model": ("UPSCALE_MODEL",), - "image": ("IMAGE",), - "facedetection": ( - [ - "retinaface_resnet50", - "retinaface_mobile0.25", - "YOLOv5l", - "YOLOv5n", - ], - ), - } - } - - RETURN_TYPES = ("IMAGE",) - - FUNCTION = "restore_face" - - CATEGORY = "facerestore" - - def restore_face(self, upscale_model, image, facedetection): - # logger.warning(f"mutex:{id(FaceRestoreWithModel._mutex):x} Facerestore with upscale_model {id(upscale_model):x} and detection model {id(facedetection):x} and image {id(image):x}") - # with FaceRestoreWithModel._mutex: - # facedetection = copy.deepcopy(facedetection) - - device = model_management.get_torch_device() - upscale_model.to(device) - face_helper = FaceRestoreHelper( - 1, - face_size=512, - crop_ratio=(1, 1), - det_model=facedetection, - save_ext="png", - use_parse=True, - device=device, - ) - - image_np = 255.0 * image.cpu().numpy().squeeze() - - image_np = image_np[:, :, ::-1] - - original_resolution = image_np.shape[0:2] - - if upscale_model is None or face_helper is None: - return image - - face_helper.clean_all() - face_helper.read_image(image_np) - face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5) - face_helper.align_warp_face() - restored_face = None - - for idx, cropped_face in enumerate(face_helper.cropped_faces): - cropped_face_t = img2tensor(cropped_face / 255.0, bgr2rgb=True, float32=True) - normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) - cropped_face_t = cropped_face_t.unsqueeze(0).to(device) - - try: - with torch.no_grad(): - # output = upscale_model(cropped_face_t, w=strength, adain=True)[0] - output = upscale_model(cropped_face_t)[0] - restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1)) - del output - # torch.cuda.empty_cache() - except Exception as error: - logger.error(f"Failed inference for CodeFormer: {error}") - restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1)) - - restored_face = restored_face.astype("uint8") - face_helper.add_restored_face(restored_face) - - face_helper.get_inverse_affine(None) - - restored_img = face_helper.paste_faces_to_input_image() - restored_img 
= restored_img[:, :, ::-1] - - if original_resolution != restored_img.shape[0:2]: - restored_img = cv2.resize( - restored_img, - (0, 0), - fx=original_resolution[1] / restored_img.shape[1], - fy=original_resolution[0] / restored_img.shape[0], - interpolation=cv2.INTER_LINEAR, - ) - - face_helper.clean_all() - - # restored_img = cv2.cvtColor(restored_face, cv2.COLOR_BGR2RGB) - - restored_img_np = np.array(restored_img).astype(np.float32) / 255.0 - restored_img_tensor = torch.from_numpy(restored_img_np).unsqueeze(0) - - return (restored_img_tensor,) - - -NODE_CLASS_MAPPINGS = { - "FaceRestoreWithModel": FaceRestoreWithModel, -} diff --git a/hordelib/nodes/facerestore/facelib/detection/yolov5face/utils/extract_ckpt.py b/hordelib/nodes/facerestore/facelib/detection/yolov5face/utils/extract_ckpt.py deleted file mode 100644 index 07e780c7..00000000 --- a/hordelib/nodes/facerestore/facelib/detection/yolov5face/utils/extract_ckpt.py +++ /dev/null @@ -1,5 +0,0 @@ -import torch -import sys -sys.path.insert(0,'./facelib/detection/yolov5face') -model = torch.load('facelib/detection/yolov5face/yolov5n-face.pt', map_location='cpu')['model'] -torch.save(model.state_dict(),'../../models/facedetection') diff --git a/hordelib/nodes/facerestore/facelib/utils/__init__.py b/hordelib/nodes/facerestore/facelib/utils/__init__.py deleted file mode 100644 index 23ef0352..00000000 --- a/hordelib/nodes/facerestore/facelib/utils/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from .face_utils import align_crop_face_landmarks, compute_increased_bbox, get_valid_bboxes, paste_face_back -from .misc import img2tensor, load_file_from_url, download_pretrained_models, scandir - -__all__ = [ - 'align_crop_face_landmarks', 'compute_increased_bbox', 'get_valid_bboxes', 'load_file_from_url', - 'download_pretrained_models', 'paste_face_back', 'img2tensor', 'scandir' -] diff --git a/hordelib/nodes/facerestore_cf/LICENSE b/hordelib/nodes/facerestore_cf/LICENSE new file mode 100644 index 00000000..f288702d --- /dev/null +++ b/hordelib/nodes/facerestore_cf/LICENSE @@ -0,0 +1,674 @@ + GNU GENERAL PUBLIC LICENSE + Version 3, 29 June 2007 + + Copyright (C) 2007 Free Software Foundation, Inc. + Everyone is permitted to copy and distribute verbatim copies + of this license document, but changing it is not allowed. + + Preamble + + The GNU General Public License is a free, copyleft license for +software and other kinds of works. + + The licenses for most software and other practical works are designed +to take away your freedom to share and change the works. By contrast, +the GNU General Public License is intended to guarantee your freedom to +share and change all versions of a program--to make sure it remains free +software for all its users. We, the Free Software Foundation, use the +GNU General Public License for most of our software; it applies also to +any other work released this way by its authors. You can apply it to +your programs, too. + + When we speak of free software, we are referring to freedom, not +price. Our General Public Licenses are designed to make sure that you +have the freedom to distribute copies of free software (and charge for +them if you wish), that you receive source code or can get it if you +want it, that you can change the software or use pieces of it in new +free programs, and that you know you can do these things. + + To protect your rights, we need to prevent others from denying you +these rights or asking you to surrender the rights. 
Therefore, you have +certain responsibilities if you distribute copies of the software, or if +you modify it: responsibilities to respect the freedom of others. + + For example, if you distribute copies of such a program, whether +gratis or for a fee, you must pass on to the recipients the same +freedoms that you received. You must make sure that they, too, receive +or can get the source code. And you must show them these terms so they +know their rights. + + Developers that use the GNU GPL protect your rights with two steps: +(1) assert copyright on the software, and (2) offer you this License +giving you legal permission to copy, distribute and/or modify it. + + For the developers' and authors' protection, the GPL clearly explains +that there is no warranty for this free software. For both users' and +authors' sake, the GPL requires that modified versions be marked as +changed, so that their problems will not be attributed erroneously to +authors of previous versions. + + Some devices are designed to deny users access to install or run +modified versions of the software inside them, although the manufacturer +can do so. This is fundamentally incompatible with the aim of +protecting users' freedom to change the software. The systematic +pattern of such abuse occurs in the area of products for individuals to +use, which is precisely where it is most unacceptable. Therefore, we +have designed this version of the GPL to prohibit the practice for those +products. If such problems arise substantially in other domains, we +stand ready to extend this provision to those domains in future versions +of the GPL, as needed to protect the freedom of users. + + Finally, every program is threatened constantly by software patents. +States should not allow patents to restrict development and use of +software on general-purpose computers, but in those that do, we wish to +avoid the special danger that patents applied to a free program could +make it effectively proprietary. To prevent this, the GPL assures that +patents cannot be used to render the program non-free. + + The precise terms and conditions for copying, distribution and +modification follow. + + TERMS AND CONDITIONS + + 0. Definitions. + + "This License" refers to version 3 of the GNU General Public License. + + "Copyright" also means copyright-like laws that apply to other kinds of +works, such as semiconductor masks. + + "The Program" refers to any copyrightable work licensed under this +License. Each licensee is addressed as "you". "Licensees" and +"recipients" may be individuals or organizations. + + To "modify" a work means to copy from or adapt all or part of the work +in a fashion requiring copyright permission, other than the making of an +exact copy. The resulting work is called a "modified version" of the +earlier work or a work "based on" the earlier work. + + A "covered work" means either the unmodified Program or a work based +on the Program. + + To "propagate" a work means to do anything with it that, without +permission, would make you directly or secondarily liable for +infringement under applicable copyright law, except executing it on a +computer or modifying a private copy. Propagation includes copying, +distribution (with or without modification), making available to the +public, and in some countries other activities as well. + + To "convey" a work means any kind of propagation that enables other +parties to make or receive copies. 
Mere interaction with a user through +a computer network, with no transfer of a copy, is not conveying. + + An interactive user interface displays "Appropriate Legal Notices" +to the extent that it includes a convenient and prominently visible +feature that (1) displays an appropriate copyright notice, and (2) +tells the user that there is no warranty for the work (except to the +extent that warranties are provided), that licensees may convey the +work under this License, and how to view a copy of this License. If +the interface presents a list of user commands or options, such as a +menu, a prominent item in the list meets this criterion. + + 1. Source Code. + + The "source code" for a work means the preferred form of the work +for making modifications to it. "Object code" means any non-source +form of a work. + + A "Standard Interface" means an interface that either is an official +standard defined by a recognized standards body, or, in the case of +interfaces specified for a particular programming language, one that +is widely used among developers working in that language. + + The "System Libraries" of an executable work include anything, other +than the work as a whole, that (a) is included in the normal form of +packaging a Major Component, but which is not part of that Major +Component, and (b) serves only to enable use of the work with that +Major Component, or to implement a Standard Interface for which an +implementation is available to the public in source code form. A +"Major Component", in this context, means a major essential component +(kernel, window system, and so on) of the specific operating system +(if any) on which the executable work runs, or a compiler used to +produce the work, or an object code interpreter used to run it. + + The "Corresponding Source" for a work in object code form means all +the source code needed to generate, install, and (for an executable +work) run the object code and to modify the work, including scripts to +control those activities. However, it does not include the work's +System Libraries, or general-purpose tools or generally available free +programs which are used unmodified in performing those activities but +which are not part of the work. For example, Corresponding Source +includes interface definition files associated with source files for +the work, and the source code for shared libraries and dynamically +linked subprograms that the work is specifically designed to require, +such as by intimate data communication or control flow between those +subprograms and other parts of the work. + + The Corresponding Source need not include anything that users +can regenerate automatically from other parts of the Corresponding +Source. + + The Corresponding Source for a work in source code form is that +same work. + + 2. Basic Permissions. + + All rights granted under this License are granted for the term of +copyright on the Program, and are irrevocable provided the stated +conditions are met. This License explicitly affirms your unlimited +permission to run the unmodified Program. The output from running a +covered work is covered by this License only if the output, given its +content, constitutes a covered work. This License acknowledges your +rights of fair use or other equivalent, as provided by copyright law. + + You may make, run and propagate covered works that you do not +convey, without conditions so long as your license otherwise remains +in force. 
You may convey covered works to others for the sole purpose +of having them make modifications exclusively for you, or provide you +with facilities for running those works, provided that you comply with +the terms of this License in conveying all material for which you do +not control copyright. Those thus making or running the covered works +for you must do so exclusively on your behalf, under your direction +and control, on terms that prohibit them from making any copies of +your copyrighted material outside their relationship with you. + + Conveying under any other circumstances is permitted solely under +the conditions stated below. Sublicensing is not allowed; section 10 +makes it unnecessary. + + 3. Protecting Users' Legal Rights From Anti-Circumvention Law. + + No covered work shall be deemed part of an effective technological +measure under any applicable law fulfilling obligations under article +11 of the WIPO copyright treaty adopted on 20 December 1996, or +similar laws prohibiting or restricting circumvention of such +measures. + + When you convey a covered work, you waive any legal power to forbid +circumvention of technological measures to the extent such circumvention +is effected by exercising rights under this License with respect to +the covered work, and you disclaim any intention to limit operation or +modification of the work as a means of enforcing, against the work's +users, your or third parties' legal rights to forbid circumvention of +technological measures. + + 4. Conveying Verbatim Copies. + + You may convey verbatim copies of the Program's source code as you +receive it, in any medium, provided that you conspicuously and +appropriately publish on each copy an appropriate copyright notice; +keep intact all notices stating that this License and any +non-permissive terms added in accord with section 7 apply to the code; +keep intact all notices of the absence of any warranty; and give all +recipients a copy of this License along with the Program. + + You may charge any price or no price for each copy that you convey, +and you may offer support or warranty protection for a fee. + + 5. Conveying Modified Source Versions. + + You may convey a work based on the Program, or the modifications to +produce it from the Program, in the form of source code under the +terms of section 4, provided that you also meet all of these conditions: + + a) The work must carry prominent notices stating that you modified + it, and giving a relevant date. + + b) The work must carry prominent notices stating that it is + released under this License and any conditions added under section + 7. This requirement modifies the requirement in section 4 to + "keep intact all notices". + + c) You must license the entire work, as a whole, under this + License to anyone who comes into possession of a copy. This + License will therefore apply, along with any applicable section 7 + additional terms, to the whole of the work, and all its parts, + regardless of how they are packaged. This License gives no + permission to license the work in any other way, but it does not + invalidate such permission if you have separately received it. + + d) If the work has interactive user interfaces, each must display + Appropriate Legal Notices; however, if the Program has interactive + interfaces that do not display Appropriate Legal Notices, your + work need not make them do so. 
+ + A compilation of a covered work with other separate and independent +works, which are not by their nature extensions of the covered work, +and which are not combined with it such as to form a larger program, +in or on a volume of a storage or distribution medium, is called an +"aggregate" if the compilation and its resulting copyright are not +used to limit the access or legal rights of the compilation's users +beyond what the individual works permit. Inclusion of a covered work +in an aggregate does not cause this License to apply to the other +parts of the aggregate. + + 6. Conveying Non-Source Forms. + + You may convey a covered work in object code form under the terms +of sections 4 and 5, provided that you also convey the +machine-readable Corresponding Source under the terms of this License, +in one of these ways: + + a) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by the + Corresponding Source fixed on a durable physical medium + customarily used for software interchange. + + b) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by a + written offer, valid for at least three years and valid for as + long as you offer spare parts or customer support for that product + model, to give anyone who possesses the object code either (1) a + copy of the Corresponding Source for all the software in the + product that is covered by this License, on a durable physical + medium customarily used for software interchange, for a price no + more than your reasonable cost of physically performing this + conveying of source, or (2) access to copy the + Corresponding Source from a network server at no charge. + + c) Convey individual copies of the object code with a copy of the + written offer to provide the Corresponding Source. This + alternative is allowed only occasionally and noncommercially, and + only if you received the object code with such an offer, in accord + with subsection 6b. + + d) Convey the object code by offering access from a designated + place (gratis or for a charge), and offer equivalent access to the + Corresponding Source in the same way through the same place at no + further charge. You need not require recipients to copy the + Corresponding Source along with the object code. If the place to + copy the object code is a network server, the Corresponding Source + may be on a different server (operated by you or a third party) + that supports equivalent copying facilities, provided you maintain + clear directions next to the object code saying where to find the + Corresponding Source. Regardless of what server hosts the + Corresponding Source, you remain obligated to ensure that it is + available for as long as needed to satisfy these requirements. + + e) Convey the object code using peer-to-peer transmission, provided + you inform other peers where the object code and Corresponding + Source of the work are being offered to the general public at no + charge under subsection 6d. + + A separable portion of the object code, whose source code is excluded +from the Corresponding Source as a System Library, need not be +included in conveying the object code work. + + A "User Product" is either (1) a "consumer product", which means any +tangible personal property which is normally used for personal, family, +or household purposes, or (2) anything designed or sold for incorporation +into a dwelling. 
In determining whether a product is a consumer product, +doubtful cases shall be resolved in favor of coverage. For a particular +product received by a particular user, "normally used" refers to a +typical or common use of that class of product, regardless of the status +of the particular user or of the way in which the particular user +actually uses, or expects or is expected to use, the product. A product +is a consumer product regardless of whether the product has substantial +commercial, industrial or non-consumer uses, unless such uses represent +the only significant mode of use of the product. + + "Installation Information" for a User Product means any methods, +procedures, authorization keys, or other information required to install +and execute modified versions of a covered work in that User Product from +a modified version of its Corresponding Source. The information must +suffice to ensure that the continued functioning of the modified object +code is in no case prevented or interfered with solely because +modification has been made. + + If you convey an object code work under this section in, or with, or +specifically for use in, a User Product, and the conveying occurs as +part of a transaction in which the right of possession and use of the +User Product is transferred to the recipient in perpetuity or for a +fixed term (regardless of how the transaction is characterized), the +Corresponding Source conveyed under this section must be accompanied +by the Installation Information. But this requirement does not apply +if neither you nor any third party retains the ability to install +modified object code on the User Product (for example, the work has +been installed in ROM). + + The requirement to provide Installation Information does not include a +requirement to continue to provide support service, warranty, or updates +for a work that has been modified or installed by the recipient, or for +the User Product in which it has been modified or installed. Access to a +network may be denied when the modification itself materially and +adversely affects the operation of the network or violates the rules and +protocols for communication across the network. + + Corresponding Source conveyed, and Installation Information provided, +in accord with this section must be in a format that is publicly +documented (and with an implementation available to the public in +source code form), and must require no special password or key for +unpacking, reading or copying. + + 7. Additional Terms. + + "Additional permissions" are terms that supplement the terms of this +License by making exceptions from one or more of its conditions. +Additional permissions that are applicable to the entire Program shall +be treated as though they were included in this License, to the extent +that they are valid under applicable law. If additional permissions +apply only to part of the Program, that part may be used separately +under those permissions, but the entire Program remains governed by +this License without regard to the additional permissions. + + When you convey a copy of a covered work, you may at your option +remove any additional permissions from that copy, or from any part of +it. (Additional permissions may be written to require their own +removal in certain cases when you modify the work.) You may place +additional permissions on material, added by you to a covered work, +for which you have or can give appropriate copyright permission. 
+ + Notwithstanding any other provision of this License, for material you +add to a covered work, you may (if authorized by the copyright holders of +that material) supplement the terms of this License with terms: + + a) Disclaiming warranty or limiting liability differently from the + terms of sections 15 and 16 of this License; or + + b) Requiring preservation of specified reasonable legal notices or + author attributions in that material or in the Appropriate Legal + Notices displayed by works containing it; or + + c) Prohibiting misrepresentation of the origin of that material, or + requiring that modified versions of such material be marked in + reasonable ways as different from the original version; or + + d) Limiting the use for publicity purposes of names of licensors or + authors of the material; or + + e) Declining to grant rights under trademark law for use of some + trade names, trademarks, or service marks; or + + f) Requiring indemnification of licensors and authors of that + material by anyone who conveys the material (or modified versions of + it) with contractual assumptions of liability to the recipient, for + any liability that these contractual assumptions directly impose on + those licensors and authors. + + All other non-permissive additional terms are considered "further +restrictions" within the meaning of section 10. If the Program as you +received it, or any part of it, contains a notice stating that it is +governed by this License along with a term that is a further +restriction, you may remove that term. If a license document contains +a further restriction but permits relicensing or conveying under this +License, you may add to a covered work material governed by the terms +of that license document, provided that the further restriction does +not survive such relicensing or conveying. + + If you add terms to a covered work in accord with this section, you +must place, in the relevant source files, a statement of the +additional terms that apply to those files, or a notice indicating +where to find the applicable terms. + + Additional terms, permissive or non-permissive, may be stated in the +form of a separately written license, or stated as exceptions; +the above requirements apply either way. + + 8. Termination. + + You may not propagate or modify a covered work except as expressly +provided under this License. Any attempt otherwise to propagate or +modify it is void, and will automatically terminate your rights under +this License (including any patent licenses granted under the third +paragraph of section 11). + + However, if you cease all violation of this License, then your +license from a particular copyright holder is reinstated (a) +provisionally, unless and until the copyright holder explicitly and +finally terminates your license, and (b) permanently, if the copyright +holder fails to notify you of the violation by some reasonable means +prior to 60 days after the cessation. + + Moreover, your license from a particular copyright holder is +reinstated permanently if the copyright holder notifies you of the +violation by some reasonable means, this is the first time you have +received notice of violation of this License (for any work) from that +copyright holder, and you cure the violation prior to 30 days after +your receipt of the notice. + + Termination of your rights under this section does not terminate the +licenses of parties who have received copies or rights from you under +this License. 
If your rights have been terminated and not permanently +reinstated, you do not qualify to receive new licenses for the same +material under section 10. + + 9. Acceptance Not Required for Having Copies. + + You are not required to accept this License in order to receive or +run a copy of the Program. Ancillary propagation of a covered work +occurring solely as a consequence of using peer-to-peer transmission +to receive a copy likewise does not require acceptance. However, +nothing other than this License grants you permission to propagate or +modify any covered work. These actions infringe copyright if you do +not accept this License. Therefore, by modifying or propagating a +covered work, you indicate your acceptance of this License to do so. + + 10. Automatic Licensing of Downstream Recipients. + + Each time you convey a covered work, the recipient automatically +receives a license from the original licensors, to run, modify and +propagate that work, subject to this License. You are not responsible +for enforcing compliance by third parties with this License. + + An "entity transaction" is a transaction transferring control of an +organization, or substantially all assets of one, or subdividing an +organization, or merging organizations. If propagation of a covered +work results from an entity transaction, each party to that +transaction who receives a copy of the work also receives whatever +licenses to the work the party's predecessor in interest had or could +give under the previous paragraph, plus a right to possession of the +Corresponding Source of the work from the predecessor in interest, if +the predecessor has it or can get it with reasonable efforts. + + You may not impose any further restrictions on the exercise of the +rights granted or affirmed under this License. For example, you may +not impose a license fee, royalty, or other charge for exercise of +rights granted under this License, and you may not initiate litigation +(including a cross-claim or counterclaim in a lawsuit) alleging that +any patent claim is infringed by making, using, selling, offering for +sale, or importing the Program or any portion of it. + + 11. Patents. + + A "contributor" is a copyright holder who authorizes use under this +License of the Program or a work on which the Program is based. The +work thus licensed is called the contributor's "contributor version". + + A contributor's "essential patent claims" are all patent claims +owned or controlled by the contributor, whether already acquired or +hereafter acquired, that would be infringed by some manner, permitted +by this License, of making, using, or selling its contributor version, +but do not include claims that would be infringed only as a +consequence of further modification of the contributor version. For +purposes of this definition, "control" includes the right to grant +patent sublicenses in a manner consistent with the requirements of +this License. + + Each contributor grants you a non-exclusive, worldwide, royalty-free +patent license under the contributor's essential patent claims, to +make, use, sell, offer for sale, import and otherwise run, modify and +propagate the contents of its contributor version. + + In the following three paragraphs, a "patent license" is any express +agreement or commitment, however denominated, not to enforce a patent +(such as an express permission to practice a patent or covenant not to +sue for patent infringement). 
To "grant" such a patent license to a +party means to make such an agreement or commitment not to enforce a +patent against the party. + + If you convey a covered work, knowingly relying on a patent license, +and the Corresponding Source of the work is not available for anyone +to copy, free of charge and under the terms of this License, through a +publicly available network server or other readily accessible means, +then you must either (1) cause the Corresponding Source to be so +available, or (2) arrange to deprive yourself of the benefit of the +patent license for this particular work, or (3) arrange, in a manner +consistent with the requirements of this License, to extend the patent +license to downstream recipients. "Knowingly relying" means you have +actual knowledge that, but for the patent license, your conveying the +covered work in a country, or your recipient's use of the covered work +in a country, would infringe one or more identifiable patents in that +country that you have reason to believe are valid. + + If, pursuant to or in connection with a single transaction or +arrangement, you convey, or propagate by procuring conveyance of, a +covered work, and grant a patent license to some of the parties +receiving the covered work authorizing them to use, propagate, modify +or convey a specific copy of the covered work, then the patent license +you grant is automatically extended to all recipients of the covered +work and works based on it. + + A patent license is "discriminatory" if it does not include within +the scope of its coverage, prohibits the exercise of, or is +conditioned on the non-exercise of one or more of the rights that are +specifically granted under this License. You may not convey a covered +work if you are a party to an arrangement with a third party that is +in the business of distributing software, under which you make payment +to the third party based on the extent of your activity of conveying +the work, and under which the third party grants, to any of the +parties who would receive the covered work from you, a discriminatory +patent license (a) in connection with copies of the covered work +conveyed by you (or copies made from those copies), or (b) primarily +for and in connection with specific products or compilations that +contain the covered work, unless you entered into that arrangement, +or that patent license was granted, prior to 28 March 2007. + + Nothing in this License shall be construed as excluding or limiting +any implied license or other defenses to infringement that may +otherwise be available to you under applicable patent law. + + 12. No Surrender of Others' Freedom. + + If conditions are imposed on you (whether by court order, agreement or +otherwise) that contradict the conditions of this License, they do not +excuse you from the conditions of this License. If you cannot convey a +covered work so as to satisfy simultaneously your obligations under this +License and any other pertinent obligations, then as a consequence you may +not convey it at all. For example, if you agree to terms that obligate you +to collect a royalty for further conveying from those to whom you convey +the Program, the only way you could satisfy both those terms and this +License would be to refrain entirely from conveying the Program. + + 13. Use with the GNU Affero General Public License. 
+ + Notwithstanding any other provision of this License, you have +permission to link or combine any covered work with a work licensed +under version 3 of the GNU Affero General Public License into a single +combined work, and to convey the resulting work. The terms of this +License will continue to apply to the part which is the covered work, +but the special requirements of the GNU Affero General Public License, +section 13, concerning interaction through a network will apply to the +combination as such. + + 14. Revised Versions of this License. + + The Free Software Foundation may publish revised and/or new versions of +the GNU General Public License from time to time. Such new versions will +be similar in spirit to the present version, but may differ in detail to +address new problems or concerns. + + Each version is given a distinguishing version number. If the +Program specifies that a certain numbered version of the GNU General +Public License "or any later version" applies to it, you have the +option of following the terms and conditions either of that numbered +version or of any later version published by the Free Software +Foundation. If the Program does not specify a version number of the +GNU General Public License, you may choose any version ever published +by the Free Software Foundation. + + If the Program specifies that a proxy can decide which future +versions of the GNU General Public License can be used, that proxy's +public statement of acceptance of a version permanently authorizes you +to choose that version for the Program. + + Later license versions may give you additional or different +permissions. However, no additional obligations are imposed on any +author or copyright holder as a result of your choosing to follow a +later version. + + 15. Disclaimer of Warranty. + + THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY +APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT +HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY +OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, +THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM +IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF +ALL NECESSARY SERVICING, REPAIR OR CORRECTION. + + 16. Limitation of Liability. + + IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING +WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS +THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY +GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE +USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF +DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD +PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), +EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF +SUCH DAMAGES. + + 17. Interpretation of Sections 15 and 16. + + If the disclaimer of warranty and limitation of liability provided +above cannot be given local legal effect according to their terms, +reviewing courts shall apply local law that most closely approximates +an absolute waiver of all civil liability in connection with the +Program, unless a warranty or assumption of liability accompanies a +copy of the Program in return for a fee. 
+ + END OF TERMS AND CONDITIONS + + How to Apply These Terms to Your New Programs + + If you develop a new program, and you want it to be of the greatest +possible use to the public, the best way to achieve this is to make it +free software which everyone can redistribute and change under these terms. + + To do so, attach the following notices to the program. It is safest +to attach them to the start of each source file to most effectively +state the exclusion of warranty; and each file should have at least +the "copyright" line and a pointer to where the full notice is found. + + + Copyright (C) + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . + +Also add information on how to contact you by electronic and paper mail. + + If the program does terminal interaction, make it output a short +notice like this when it starts in an interactive mode: + + Copyright (C) + This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'. + This is free software, and you are welcome to redistribute it + under certain conditions; type `show c' for details. + +The hypothetical commands `show w' and `show c' should show the appropriate +parts of the General Public License. Of course, your program's commands +might be different; for a GUI interface, you would use an "about box". + + You should also get your employer (if you work as a programmer) or school, +if any, to sign a "copyright disclaimer" for the program, if necessary. +For more information on this, and how to apply and follow the GNU GPL, see +. + + The GNU General Public License does not permit incorporating your program +into proprietary programs. If your program is a subroutine library, you +may consider it more useful to permit linking proprietary applications with +the library. If this is what you want to do, use the GNU Lesser General +Public License instead of this License. But first, please read +. diff --git a/hordelib/nodes/facerestore_cf/README.md b/hordelib/nodes/facerestore_cf/README.md new file mode 100644 index 00000000..abe522cb --- /dev/null +++ b/hordelib/nodes/facerestore_cf/README.md @@ -0,0 +1 @@ +Packaged code in this directory (unless stated otherwise) licensed under GPL and sourced from https://github.com/mav-rik/facerestore_cf. See LICENSE for more information. 
diff --git a/hordelib/nodes/facerestore_cf/__init__.py b/hordelib/nodes/facerestore_cf/__init__.py new file mode 100644 index 00000000..81c6c1c6 --- /dev/null +++ b/hordelib/nodes/facerestore_cf/__init__.py @@ -0,0 +1,333 @@ +import math +import os +import sys + +import comfy.utils +import cv2 +import folder_paths +import model_management +import numpy as np +import torch +# from comfy_extras.chainner_models import model_loading +from hordelib.nodes.facerestore_cf.r_chainner import model_loading +from torchvision.transforms.functional import normalize + +from hordelib.nodes.facerestore_cf.basicsr.utils.registry import ARCH_REGISTRY +from hordelib.nodes.facerestore_cf.facelib.detection.retinaface import retinaface +from hordelib.nodes.facerestore_cf.facelib.utils.face_restoration_helper import FaceRestoreHelper + +# import codeformer_arch + + +def img2tensor(imgs, bgr2rgb=True, float32=True): + """Numpy array to tensor. + + Args: + imgs (list[ndarray] | ndarray): Input images. + bgr2rgb (bool): Whether to change bgr to rgb. + float32 (bool): Whether to change to float32. + + Returns: + list[tensor] | tensor: Tensor images. If returned results only have + one element, just return tensor. + """ + + def _totensor(img, bgr2rgb, float32): + if img.shape[2] == 3 and bgr2rgb: + if img.dtype == "float64": + img = img.astype("float32") + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = torch.from_numpy(img.transpose(2, 0, 1)) + if float32: + img = img.float() + return img + + if isinstance(imgs, list): + return [_totensor(img, bgr2rgb, float32) for img in imgs] + else: + return _totensor(imgs, bgr2rgb, float32) + + +def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)): + """Convert torch Tensors into image numpy arrays. + + After clamping to [min, max], values will be normalized to [0, 1]. + + Args: + tensor (Tensor or list[Tensor]): Accept shapes: + 1) 4D mini-batch Tensor of shape (B x 3/1 x H x W); + 2) 3D Tensor of shape (3/1 x H x W); + 3) 2D Tensor of shape (H x W). + Tensor channel should be in RGB order. + rgb2bgr (bool): Whether to change rgb to bgr. + out_type (numpy type): output types. If ``np.uint8``, transform outputs + to uint8 type with range [0, 255]; otherwise, float type with + range [0, 1]. Default: ``np.uint8``. + min_max (tuple[int]): min and max values for clamp. + + Returns: + (Tensor or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of + shape (H x W). The channel order is BGR. + """ + if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))): + raise TypeError(f"tensor or list of tensors expected, got {type(tensor)}") + + if torch.is_tensor(tensor): + tensor = [tensor] + result = [] + for _tensor in tensor: + _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max) + _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0]) + + n_dim = _tensor.dim() + if n_dim == 4: + img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy() + img_np = img_np.transpose(1, 2, 0) + if rgb2bgr: + img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) + elif n_dim == 3: + img_np = _tensor.numpy() + img_np = img_np.transpose(1, 2, 0) + if img_np.shape[2] == 1: # gray image + img_np = np.squeeze(img_np, axis=2) + else: + if rgb2bgr: + img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) + elif n_dim == 2: + img_np = _tensor.numpy() + else: + raise TypeError("Only support 4D, 3D or 2D tensor. 
" f"But received with dimension: {n_dim}") + if out_type == np.uint8: + # Unlike MATLAB, numpy.unit8() WILL NOT round by default. + img_np = (img_np * 255.0).round() + img_np = img_np.astype(out_type) + result.append(img_np) + if len(result) == 1: + result = result[0] + return result + + +class FaceRestoreCFWithModel: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "facerestore_model": ("FACERESTORE_MODEL",), + "image": ("IMAGE",), + "facedetection": (["retinaface_resnet50", "retinaface_mobile0.25", "YOLOv5l", "YOLOv5n"],), + "codeformer_fidelity": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1, "step": 0.05}), + }, + } + + RETURN_TYPES = ("IMAGE",) + + FUNCTION = "restore_face" + + CATEGORY = "facerestore_cf" + + def __init__(self): + self.face_helper = None + + def restore_face(self, facerestore_model, image, facedetection, codeformer_fidelity): + print(f"\tStarting restore_face with codeformer_fidelity: {codeformer_fidelity}") + device = model_management.get_torch_device() + facerestore_model.to(device) + if self.face_helper is None: + self.face_helper = FaceRestoreHelper( + 1, + face_size=512, + crop_ratio=(1, 1), + det_model=facedetection, + save_ext="png", + use_parse=True, + device=device, + ) + + image_np = 255.0 * image.cpu().numpy() + + total_images = image_np.shape[0] + out_images = np.ndarray(shape=image_np.shape) + + for i in range(total_images): + cur_image_np = image_np[i, :, :, ::-1] + + original_resolution = cur_image_np.shape[0:2] + + if facerestore_model is None or self.face_helper is None: + return image + + self.face_helper.clean_all() + self.face_helper.read_image(cur_image_np) + self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5) + self.face_helper.align_warp_face() + + restored_face = None + for idx, cropped_face in enumerate(self.face_helper.cropped_faces): + cropped_face_t = img2tensor(cropped_face / 255.0, bgr2rgb=True, float32=True) + normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) + cropped_face_t = cropped_face_t.unsqueeze(0).to(device) + + try: + with torch.no_grad(): + # output = facerestore_model(cropped_face_t, w=strength, adain=True)[0] + # output = facerestore_model(cropped_face_t)[0] + output = facerestore_model(cropped_face_t, w=codeformer_fidelity)[0] + restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1)) + del output + torch.cuda.empty_cache() + except Exception as error: + print(f"\tFailed inference for CodeFormer: {error}", file=sys.stderr) + restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1)) + + restored_face = restored_face.astype("uint8") + self.face_helper.add_restored_face(restored_face) + + self.face_helper.get_inverse_affine(None) + + restored_img = self.face_helper.paste_faces_to_input_image() + restored_img = restored_img[:, :, ::-1] + + if original_resolution != restored_img.shape[0:2]: + restored_img = cv2.resize( + restored_img, + (0, 0), + fx=original_resolution[1] / restored_img.shape[1], + fy=original_resolution[0] / restored_img.shape[0], + interpolation=cv2.INTER_LINEAR, + ) + + self.face_helper.clean_all() + + # restored_img = cv2.cvtColor(restored_face, cv2.COLOR_BGR2RGB) + + out_images[i] = restored_img + + restored_img_np = np.array(out_images).astype(np.float32) / 255.0 + restored_img_tensor = torch.from_numpy(restored_img_np) + return (restored_img_tensor,) + + +class CropFace: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "image": ("IMAGE",), + "facedetection": 
(["retinaface_resnet50", "retinaface_mobile0.25", "YOLOv5l", "YOLOv5n"],), + }, + } + + RETURN_TYPES = ("IMAGE",) + + FUNCTION = "crop_face" + + CATEGORY = "facerestore_cf" + + def __init__(self): + self.face_helper = None + + def crop_face(self, image, facedetection): + device = model_management.get_torch_device() + if self.face_helper is None: + self.face_helper = FaceRestoreHelper( + 1, + face_size=512, + crop_ratio=(1, 1), + det_model=facedetection, + save_ext="png", + use_parse=True, + device=device, + ) + + image_np = 255.0 * image.cpu().numpy() + + total_images = image_np.shape[0] + out_images = np.ndarray(shape=(total_images, 512, 512, 3)) + next_idx = 0 + + for i in range(total_images): + + cur_image_np = image_np[i, :, :, ::-1] + + original_resolution = cur_image_np.shape[0:2] + + if self.face_helper is None: + return image + + self.face_helper.clean_all() + self.face_helper.read_image(cur_image_np) + self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5) + self.face_helper.align_warp_face() + + faces_found = len(self.face_helper.cropped_faces) + if faces_found == 0: + next_idx += 1 # output black image for no face + if out_images.shape[0] < next_idx + faces_found: + print(out_images.shape) + print((next_idx + faces_found, 512, 512, 3)) + print("aaaaa") + out_images = np.resize(out_images, (next_idx + faces_found, 512, 512, 3)) + print(out_images.shape) + for j in range(faces_found): + cropped_face_1 = self.face_helper.cropped_faces[j] + cropped_face_2 = img2tensor(cropped_face_1 / 255.0, bgr2rgb=True, float32=True) + normalize(cropped_face_2, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) + cropped_face_3 = cropped_face_2.unsqueeze(0).to(device) + cropped_face_4 = tensor2img(cropped_face_3, rgb2bgr=True, min_max=(-1, 1)).astype("uint8") + cropped_face_5 = cv2.cvtColor(cropped_face_4, cv2.COLOR_BGR2RGB) + out_images[next_idx] = cropped_face_5 + next_idx += 1 + + cropped_face_6 = np.array(out_images).astype(np.float32) / 255.0 + cropped_face_7 = torch.from_numpy(cropped_face_6) + return (cropped_face_7,) + + +class FaceRestoreModelLoader: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "model_name": (folder_paths.get_filename_list("facerestore_models"),), + }, + } + + RETURN_TYPES = ("FACERESTORE_MODEL",) + FUNCTION = "load_model" + + CATEGORY = "facerestore_cf" + + # def load_model(self, model_name): + # model_path = folder_paths.get_full_path("facerestore_models", model_name) + # sd = comfy.utils.load_torch_file(model_path, safe_load=True) + # out = model_loading.load_state_dict(sd).eval() + # return (out, ) + + def load_model(self, model_name): + if "codeformer" in model_name.lower(): + print(f"\tLoading CodeFormer: {model_name}") + model_path = folder_paths.get_full_path("facerestore_models", model_name) + device = model_management.get_torch_device() + codeformer_net = ARCH_REGISTRY.get("CodeFormer")( + dim_embd=512, + codebook_size=1024, + n_head=8, + n_layers=9, + connect_list=["32", "64", "128", "256"], + ).to(device) + checkpoint = torch.load(model_path)["params_ema"] + codeformer_net.load_state_dict(checkpoint) + out = codeformer_net.eval() + return (out,) + else: + model_path = folder_paths.get_full_path("facerestore_models", model_name) + sd = comfy.utils.load_torch_file(model_path, safe_load=True) + out = model_loading.load_state_dict(sd).eval() + return (out,) + + +NODE_CLASS_MAPPINGS = { + "FaceRestoreCFWithModel": FaceRestoreCFWithModel, + "CropFace": CropFace, + "FaceRestoreModelLoader": 
FaceRestoreModelLoader, +} diff --git a/hordelib/nodes/facerestore_cf/basicsr/VERSION b/hordelib/nodes/facerestore_cf/basicsr/VERSION new file mode 100644 index 00000000..b85bccc7 --- /dev/null +++ b/hordelib/nodes/facerestore_cf/basicsr/VERSION @@ -0,0 +1 @@ +1.3.2 diff --git a/hordelib/nodes/facerestore_cf/basicsr/__init__.py b/hordelib/nodes/facerestore_cf/basicsr/__init__.py new file mode 100644 index 00000000..2a06af02 --- /dev/null +++ b/hordelib/nodes/facerestore_cf/basicsr/__init__.py @@ -0,0 +1,11 @@ +# https://github.com/xinntao/BasicSR +# flake8: noqa +from .archs import * +from .data import * +from .losses import * +from .metrics import * +from .models import * +from .ops import * +from .train import * +from .utils import * +from .version import __gitsha__, __version__ diff --git a/hordelib/nodes/facerestore_cf/basicsr/archs/__init__.py b/hordelib/nodes/facerestore_cf/basicsr/archs/__init__.py new file mode 100644 index 00000000..41b0cbc3 --- /dev/null +++ b/hordelib/nodes/facerestore_cf/basicsr/archs/__init__.py @@ -0,0 +1,27 @@ +import importlib +from copy import deepcopy +from os import path as osp + +from hordelib.nodes.facerestore_cf.basicsr.utils import get_root_logger, scandir +from hordelib.nodes.facerestore_cf.basicsr.utils.registry import ARCH_REGISTRY + +__all__ = ["build_network"] + +# automatically scan and import arch modules for registry +# scan all the files under the 'archs' folder and collect files ending with +# '_arch.py' +arch_folder = osp.dirname(osp.abspath(__file__)) +arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith("_arch.py")] +# import all the arch modules +_arch_modules = [ + importlib.import_module(f"hordelib.nodes.facerestore_cf.basicsr.archs.{file_name}") for file_name in arch_filenames +] + + +def build_network(opt): + opt = deepcopy(opt) + network_type = opt.pop("type") + net = ARCH_REGISTRY.get(network_type)(**opt) + logger = get_root_logger() + logger.info(f"Network [{net.__class__.__name__}] is created.") + return net diff --git a/hordelib/nodes/facerestore_cf/basicsr/archs/arcface_arch.py b/hordelib/nodes/facerestore_cf/basicsr/archs/arcface_arch.py new file mode 100644 index 00000000..b2627251 --- /dev/null +++ b/hordelib/nodes/facerestore_cf/basicsr/archs/arcface_arch.py @@ -0,0 +1,252 @@ +import torch.nn as nn + +from hordelib.nodes.facerestore_cf.basicsr.utils.registry import ARCH_REGISTRY + + +def conv3x3(inplanes, outplanes, stride=1): + """A simple wrapper for 3x3 convolution with padding. + + Args: + inplanes (int): Channel number of inputs. + outplanes (int): Channel number of outputs. + stride (int): Stride in convolution. Default: 1. + """ + return nn.Conv2d(inplanes, outplanes, kernel_size=3, stride=stride, padding=1, bias=False) + + +class BasicBlock(nn.Module): + """Basic residual block used in the ResNetArcFace architecture. + + Args: + inplanes (int): Channel number of inputs. + planes (int): Channel number of outputs. + stride (int): Stride in convolution. Default: 1. + downsample (nn.Module): The downsample module. Default: None. 
+ """ + + expansion = 1 # output channel expansion ratio + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = nn.BatchNorm2d(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class IRBlock(nn.Module): + """Improved residual block (IR Block) used in the ResNetArcFace architecture. + + Args: + inplanes (int): Channel number of inputs. + planes (int): Channel number of outputs. + stride (int): Stride in convolution. Default: 1. + downsample (nn.Module): The downsample module. Default: None. + use_se (bool): Whether use the SEBlock (squeeze and excitation block). Default: True. + """ + + expansion = 1 # output channel expansion ratio + + def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True): + super(IRBlock, self).__init__() + self.bn0 = nn.BatchNorm2d(inplanes) + self.conv1 = conv3x3(inplanes, inplanes) + self.bn1 = nn.BatchNorm2d(inplanes) + self.prelu = nn.PReLU() + self.conv2 = conv3x3(inplanes, planes, stride) + self.bn2 = nn.BatchNorm2d(planes) + self.downsample = downsample + self.stride = stride + self.use_se = use_se + if self.use_se: + self.se = SEBlock(planes) + + def forward(self, x): + residual = x + out = self.bn0(x) + out = self.conv1(out) + out = self.bn1(out) + out = self.prelu(out) + + out = self.conv2(out) + out = self.bn2(out) + if self.use_se: + out = self.se(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.prelu(out) + + return out + + +class Bottleneck(nn.Module): + """Bottleneck block used in the ResNetArcFace architecture. + + Args: + inplanes (int): Channel number of inputs. + planes (int): Channel number of outputs. + stride (int): Stride in convolution. Default: 1. + downsample (nn.Module): The downsample module. Default: None. + """ + + expansion = 4 # output channel expansion ratio + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class SEBlock(nn.Module): + """The squeeze-and-excitation block (SEBlock) used in the IRBlock. + + Args: + channel (int): Channel number of inputs. + reduction (int): Channel reduction ration. Default: 16. 
+ """ + + def __init__(self, channel, reduction=16): + super(SEBlock, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) # pool to 1x1 without spatial information + self.fc = nn.Sequential( + nn.Linear(channel, channel // reduction), + nn.PReLU(), + nn.Linear(channel // reduction, channel), + nn.Sigmoid(), + ) + + def forward(self, x): + b, c, _, _ = x.size() + y = self.avg_pool(x).view(b, c) + y = self.fc(y).view(b, c, 1, 1) + return x * y + + +@ARCH_REGISTRY.register() +class ResNetArcFace(nn.Module): + """ArcFace with ResNet architectures. + + Ref: ArcFace: Additive Angular Margin Loss for Deep Face Recognition. + + Args: + block (str): Block used in the ArcFace architecture. + layers (tuple(int)): Block numbers in each layer. + use_se (bool): Whether use the SEBlock (squeeze and excitation block). Default: True. + """ + + def __init__(self, block, layers, use_se=True): + if block == "IRBlock": + block = IRBlock + self.inplanes = 64 + self.use_se = use_se + super(ResNetArcFace, self).__init__() + + self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.prelu = nn.PReLU() + self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2) + self.bn4 = nn.BatchNorm2d(512) + self.dropout = nn.Dropout() + self.fc5 = nn.Linear(512 * 8 * 8, 512) + self.bn5 = nn.BatchNorm1d(512) + + # initialization + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.xavier_normal_(m.weight) + elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + nn.init.xavier_normal_(m.weight) + nn.init.constant_(m.bias, 0) + + def _make_layer(self, block, planes, num_blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample, use_se=self.use_se)) + self.inplanes = planes + for _ in range(1, num_blocks): + layers.append(block(self.inplanes, planes, use_se=self.use_se)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.prelu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.bn4(x) + x = self.dropout(x) + x = x.view(x.size(0), -1) + x = self.fc5(x) + x = self.bn5(x) + + return x diff --git a/hordelib/nodes/facerestore_cf/basicsr/archs/arch_util.py b/hordelib/nodes/facerestore_cf/basicsr/archs/arch_util.py new file mode 100644 index 00000000..a40bf5f2 --- /dev/null +++ b/hordelib/nodes/facerestore_cf/basicsr/archs/arch_util.py @@ -0,0 +1,342 @@ +import collections.abc +import math +import warnings +from distutils.version import LooseVersion +from itertools import repeat + +import torch +import torchvision +from torch import nn as nn +from torch.nn import functional as F +from torch.nn import init as init +from torch.nn.modules.batchnorm import _BatchNorm + +from hordelib.nodes.facerestore_cf.basicsr.ops.dcn import ModulatedDeformConvPack, modulated_deform_conv +from 
hordelib.nodes.facerestore_cf.basicsr.utils import get_root_logger + + +@torch.no_grad() +def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs): + """Initialize network weights. + + Args: + module_list (list[nn.Module] | nn.Module): Modules to be initialized. + scale (float): Scale initialized weights, especially for residual + blocks. Default: 1. + bias_fill (float): The value to fill bias. Default: 0 + kwargs (dict): Other arguments for initialization function. + """ + if not isinstance(module_list, list): + module_list = [module_list] + for module in module_list: + for m in module.modules(): + if isinstance(m, nn.Conv2d): + init.kaiming_normal_(m.weight, **kwargs) + m.weight.data *= scale + if m.bias is not None: + m.bias.data.fill_(bias_fill) + elif isinstance(m, nn.Linear): + init.kaiming_normal_(m.weight, **kwargs) + m.weight.data *= scale + if m.bias is not None: + m.bias.data.fill_(bias_fill) + elif isinstance(m, _BatchNorm): + init.constant_(m.weight, 1) + if m.bias is not None: + m.bias.data.fill_(bias_fill) + + +def make_layer(basic_block, num_basic_block, **kwarg): + """Make layers by stacking the same blocks. + + Args: + basic_block (nn.module): nn.module class for basic block. + num_basic_block (int): number of blocks. + + Returns: + nn.Sequential: Stacked blocks in nn.Sequential. + """ + layers = [] + for _ in range(num_basic_block): + layers.append(basic_block(**kwarg)) + return nn.Sequential(*layers) + + +class ResidualBlockNoBN(nn.Module): + """Residual block without BN. + + It has a style of: + ---Conv-ReLU-Conv-+- + |________________| + + Args: + num_feat (int): Channel number of intermediate features. + Default: 64. + res_scale (float): Residual scale. Default: 1. + pytorch_init (bool): If set to True, use pytorch default init, + otherwise, use default_init_weights. Default: False. + """ + + def __init__(self, num_feat=64, res_scale=1, pytorch_init=False): + super(ResidualBlockNoBN, self).__init__() + self.res_scale = res_scale + self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True) + self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True) + self.relu = nn.ReLU(inplace=True) + + if not pytorch_init: + default_init_weights([self.conv1, self.conv2], 0.1) + + def forward(self, x): + identity = x + out = self.conv2(self.relu(self.conv1(x))) + return identity + out * self.res_scale + + +class Upsample(nn.Sequential): + """Upsample module. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + """ + + def __init__(self, scale, num_feat): + m = [] + if (scale & (scale - 1)) == 0: # scale = 2^n + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(2)) + elif scale == 3: + m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(3)) + else: + raise ValueError(f"scale {scale} is not supported. Supported scales: 2^n and 3.") + super(Upsample, self).__init__(*m) + + +def flow_warp(x, flow, interp_mode="bilinear", padding_mode="zeros", align_corners=True): + """Warp an image or feature map with optical flow. + + Args: + x (Tensor): Tensor with size (n, c, h, w). + flow (Tensor): Tensor with size (n, h, w, 2), normal value. + interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'. + padding_mode (str): 'zeros' or 'border' or 'reflection'. + Default: 'zeros'. + align_corners (bool): Before pytorch 1.3, the default value is + align_corners=True. 
After pytorch 1.3, the default value is + align_corners=False. Here, we use the True as default. + + Returns: + Tensor: Warped image or feature map. + """ + assert x.size()[-2:] == flow.size()[1:3] + _, _, h, w = x.size() + # create mesh grid + grid_y, grid_x = torch.meshgrid(torch.arange(0, h).type_as(x), torch.arange(0, w).type_as(x)) + grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2 + grid.requires_grad = False + + vgrid = grid + flow + # scale grid to [-1,1] + vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0 + vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0 + vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3) + output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode, align_corners=align_corners) + + # TODO, what if align_corners=False + return output + + +def resize_flow(flow, size_type, sizes, interp_mode="bilinear", align_corners=False): + """Resize a flow according to ratio or shape. + + Args: + flow (Tensor): Precomputed flow. shape [N, 2, H, W]. + size_type (str): 'ratio' or 'shape'. + sizes (list[int | float]): the ratio for resizing or the final output + shape. + 1) The order of ratio should be [ratio_h, ratio_w]. For + downsampling, the ratio should be smaller than 1.0 (i.e., ratio + < 1.0). For upsampling, the ratio should be larger than 1.0 (i.e., + ratio > 1.0). + 2) The order of output_size should be [out_h, out_w]. + interp_mode (str): The mode of interpolation for resizing. + Default: 'bilinear'. + align_corners (bool): Whether align corners. Default: False. + + Returns: + Tensor: Resized flow. + """ + _, _, flow_h, flow_w = flow.size() + if size_type == "ratio": + output_h, output_w = int(flow_h * sizes[0]), int(flow_w * sizes[1]) + elif size_type == "shape": + output_h, output_w = sizes[0], sizes[1] + else: + raise ValueError(f"Size type should be ratio or shape, but got type {size_type}.") + + input_flow = flow.clone() + ratio_h = output_h / flow_h + ratio_w = output_w / flow_w + input_flow[:, 0, :, :] *= ratio_w + input_flow[:, 1, :, :] *= ratio_h + resized_flow = F.interpolate( + input=input_flow, + size=(output_h, output_w), + mode=interp_mode, + align_corners=align_corners, + ) + return resized_flow + + +# TODO: may write a cpp file +def pixel_unshuffle(x, scale): + """Pixel unshuffle. + + Args: + x (Tensor): Input feature with shape (b, c, hh, hw). + scale (int): Downsample ratio. + + Returns: + Tensor: the pixel unshuffled feature. + """ + b, c, hh, hw = x.size() + out_channel = c * (scale**2) + assert hh % scale == 0 and hw % scale == 0 + h = hh // scale + w = hw // scale + x_view = x.view(b, c, h, scale, w, scale) + return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w) + + +class DCNv2Pack(ModulatedDeformConvPack): + """Modulated deformable conv for deformable alignment. + + Different from the official DCNv2Pack, which generates offsets and masks + from the preceding features, this DCNv2Pack takes another different + features to generate offsets and masks. + + Ref: + Delving Deep into Deformable Alignment in Video Super-Resolution. 
+ """ + + def forward(self, x, feat): + out = self.conv_offset(feat) + o1, o2, mask = torch.chunk(out, 3, dim=1) + offset = torch.cat((o1, o2), dim=1) + mask = torch.sigmoid(mask) + + offset_absmean = torch.mean(torch.abs(offset)) + if offset_absmean > 50: + logger = get_root_logger() + logger.warning(f"Offset abs mean is {offset_absmean}, larger than 50.") + + if LooseVersion(torchvision.__version__) >= LooseVersion("0.9.0"): + return torchvision.ops.deform_conv2d( + x, + offset, + self.weight, + self.bias, + self.stride, + self.padding, + self.dilation, + mask, + ) + else: + return modulated_deform_conv( + x, + offset, + mask, + self.weight, + self.bias, + self.stride, + self.padding, + self.dilation, + self.groups, + self.deformable_groups, + ) + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn( + "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2, + ) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + low = norm_cdf((a - mean) / std) + up = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [low, up], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * low - 1, 2 * up - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.0)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0): + r"""Fills the input Tensor with values drawn from a truncated + normal distribution. + + From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py + + The values are effectively drawn from the + normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. 
+ + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.trunc_normal_(w) + """ + return _no_grad_trunc_normal_(tensor, mean, std, a, b) + + +# From PyTorch +def _ntuple(n): + + def parse(x): + if isinstance(x, collections.abc.Iterable): + return x + return tuple(repeat(x, n)) + + return parse + + +to_1tuple = _ntuple(1) +to_2tuple = _ntuple(2) +to_3tuple = _ntuple(3) +to_4tuple = _ntuple(4) +to_ntuple = _ntuple diff --git a/hordelib/nodes/facerestore_cf/basicsr/archs/codeformer_arch.py b/hordelib/nodes/facerestore_cf/basicsr/archs/codeformer_arch.py new file mode 100644 index 00000000..e70ae4c0 --- /dev/null +++ b/hordelib/nodes/facerestore_cf/basicsr/archs/codeformer_arch.py @@ -0,0 +1,291 @@ +import math + +import torch +import torch.nn.functional as F +from torch import Tensor, nn + +from hordelib.nodes.facerestore_cf.basicsr.archs.vqgan_arch import * +from hordelib.nodes.facerestore_cf.basicsr.utils.registry import ARCH_REGISTRY + + +def calc_mean_std(feat, eps=1e-5): + """Calculate mean and std for adaptive_instance_normalization. + + Args: + feat (Tensor): 4D tensor. + eps (float): A small value added to the variance to avoid + divide-by-zero. Default: 1e-5. + """ + size = feat.size() + assert len(size) == 4, "The input feature should be 4D tensor." + b, c = size[:2] + feat_var = feat.view(b, c, -1).var(dim=2) + eps + feat_std = feat_var.sqrt().view(b, c, 1, 1) + feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1) + return feat_mean, feat_std + + +def adaptive_instance_normalization(content_feat, style_feat): + """Adaptive instance normalization. + + Adjust the reference features to have the similar color and illuminations + as those in the degradate features. + + Args: + content_feat (Tensor): The reference feature. + style_feat (Tensor): The degradate features. + """ + size = content_feat.size() + style_mean, style_std = calc_mean_std(style_feat) + content_mean, content_std = calc_mean_std(content_feat) + normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size) + return normalized_feat * style_std.expand(size) + style_mean.expand(size) + + +class PositionEmbeddingSine(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one + used by the Attention is all you need paper, generalized to work on images. 
+ """ + + def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): + super().__init__() + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * math.pi + self.scale = scale + + def forward(self, x, mask=None): + if mask is None: + mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool) + not_mask = ~mask + y_embed = not_mask.cumsum(1, dtype=torch.float32) + x_embed = not_mask.cumsum(2, dtype=torch.float32) + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack( + (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), + dim=4, + ).flatten(3) + pos_y = torch.stack( + (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), + dim=4, + ).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos + + +def _get_activation_fn(activation): + """Return an activation function given a string""" + if activation == "relu": + return F.relu + if activation == "gelu": + return F.gelu + if activation == "glu": + return F.glu + raise RuntimeError(f"activation should be relu/gelu, not {activation}.") + + +class TransformerSALayer(nn.Module): + def __init__(self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, activation="gelu"): + super().__init__() + self.self_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout) + # Implementation of Feedforward model - MLP + self.linear1 = nn.Linear(embed_dim, dim_mlp) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_mlp, embed_dim) + + self.norm1 = nn.LayerNorm(embed_dim) + self.norm2 = nn.LayerNorm(embed_dim) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + + def with_pos_embed(self, tensor, pos: Tensor | None): + return tensor if pos is None else tensor + pos + + def forward( + self, + tgt, + tgt_mask: Tensor | None = None, + tgt_key_padding_mask: Tensor | None = None, + query_pos: Tensor | None = None, + ): + + # self attention + tgt2 = self.norm1(tgt) + q = k = self.with_pos_embed(tgt2, query_pos) + tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0] + tgt = tgt + self.dropout1(tgt2) + + # ffn + tgt2 = self.norm2(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) + tgt = tgt + self.dropout2(tgt2) + return tgt + + +class Fuse_sft_block(nn.Module): + def __init__(self, in_ch, out_ch): + super().__init__() + self.encode_enc = ResBlock(2 * in_ch, out_ch) + + self.scale = nn.Sequential( + nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1), + nn.LeakyReLU(0.2, True), + nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1), + ) + + self.shift = nn.Sequential( + nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1), + nn.LeakyReLU(0.2, True), + nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1), + ) + + def forward(self, enc_feat, dec_feat, w=1): + enc_feat = self.encode_enc(torch.cat([enc_feat, dec_feat], dim=1)) + scale = self.scale(enc_feat) + shift = 
self.shift(enc_feat) + residual = w * (dec_feat * scale + shift) + out = dec_feat + residual + return out + + +@ARCH_REGISTRY.register() +class CodeFormer(VQAutoEncoder): + def __init__( + self, + dim_embd=512, + n_head=8, + n_layers=9, + codebook_size=1024, + latent_size=256, + connect_list=["32", "64", "128", "256"], + fix_modules=["quantize", "generator"], + ): + super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], "nearest", 2, [16], codebook_size) + + if fix_modules is not None: + for module in fix_modules: + for param in getattr(self, module).parameters(): + param.requires_grad = False + + self.connect_list = connect_list + self.n_layers = n_layers + self.dim_embd = dim_embd + self.dim_mlp = dim_embd * 2 + + self.position_emb = nn.Parameter(torch.zeros(latent_size, self.dim_embd)) + self.feat_emb = nn.Linear(256, self.dim_embd) + + # transformer + self.ft_layers = nn.Sequential( + *[ + TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0) + for _ in range(self.n_layers) + ], + ) + + # logits_predict head + self.idx_pred_layer = nn.Sequential(nn.LayerNorm(dim_embd), nn.Linear(dim_embd, codebook_size, bias=False)) + + self.channels = { + "16": 512, + "32": 256, + "64": 256, + "128": 128, + "256": 128, + "512": 64, + } + + # after second residual block for > 16, before attn layer for ==16 + self.fuse_encoder_block = {"512": 2, "256": 5, "128": 8, "64": 11, "32": 14, "16": 18} + # after first residual block for > 16, before attn layer for ==16 + self.fuse_generator_block = {"16": 6, "32": 9, "64": 12, "128": 15, "256": 18, "512": 21} + + # fuse_convs_dict + self.fuse_convs_dict = nn.ModuleDict() + for f_size in self.connect_list: + in_ch = self.channels[f_size] + self.fuse_convs_dict[f_size] = Fuse_sft_block(in_ch, in_ch) + + def _init_weights(self, module): + if isinstance(module, (nn.Linear, nn.Embedding)): + module.weight.data.normal_(mean=0.0, std=0.02) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + def forward(self, x, w=0, detach_16=True, code_only=False, adain=False): + # ################### Encoder ##################### + enc_feat_dict = {} + out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list] + for i, block in enumerate(self.encoder.blocks): + x = block(x) + if i in out_list: + enc_feat_dict[str(x.shape[-1])] = x.clone() + + lq_feat = x + # ################# Transformer ################### + # quant_feat, codebook_loss, quant_stats = self.quantize(lq_feat) + pos_emb = self.position_emb.unsqueeze(1).repeat(1, x.shape[0], 1) + # BCHW -> BC(HW) -> (HW)BC + feat_emb = self.feat_emb(lq_feat.flatten(2).permute(2, 0, 1)) + query_emb = feat_emb + # Transformer encoder + for layer in self.ft_layers: + query_emb = layer(query_emb, query_pos=pos_emb) + + # output logits + logits = self.idx_pred_layer(query_emb) # (hw)bn + logits = logits.permute(1, 0, 2) # (hw)bn -> b(hw)n + + if code_only: # for training stage II + # logits doesn't need softmax before cross_entropy loss + return logits, lq_feat + + # ################# Quantization ################### + # if self.training: + # quant_feat = torch.einsum('btn,nc->btc', [soft_one_hot, self.quantize.embedding.weight]) + # # b(hw)c -> bc(hw) -> bchw + # quant_feat = quant_feat.permute(0,2,1).view(lq_feat.shape) + # ------------ + soft_one_hot = F.softmax(logits, dim=2) + _, top_idx = torch.topk(soft_one_hot, 1, dim=2) + quant_feat = 
self.quantize.get_codebook_feat(top_idx, shape=[x.shape[0], 16, 16, 256]) + # preserve gradients + # quant_feat = lq_feat + (quant_feat - lq_feat).detach() + + if detach_16: + quant_feat = quant_feat.detach() # for training stage III + if adain: + quant_feat = adaptive_instance_normalization(quant_feat, lq_feat) + + # ################## Generator #################### + x = quant_feat + fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list] + + for i, block in enumerate(self.generator.blocks): + x = block(x) + if i in fuse_list: # fuse after i-th block + f_size = str(x.shape[-1]) + if w > 0: + x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w) + out = x + # logits doesn't need softmax before cross_entropy loss + return out, logits, lq_feat diff --git a/hordelib/nodes/facerestore_cf/basicsr/archs/rrdbnet_arch.py b/hordelib/nodes/facerestore_cf/basicsr/archs/rrdbnet_arch.py new file mode 100644 index 00000000..bc0d00ab --- /dev/null +++ b/hordelib/nodes/facerestore_cf/basicsr/archs/rrdbnet_arch.py @@ -0,0 +1,120 @@ +import torch +from torch import nn as nn +from torch.nn import functional as F + +from hordelib.nodes.facerestore_cf.basicsr.utils.registry import ARCH_REGISTRY + +from .arch_util import default_init_weights, make_layer, pixel_unshuffle + + +class ResidualDenseBlock(nn.Module): + """Residual Dense Block. + + Used in RRDB block in ESRGAN. + + Args: + num_feat (int): Channel number of intermediate features. + num_grow_ch (int): Channels for each growth. + """ + + def __init__(self, num_feat=64, num_grow_ch=32): + super(ResidualDenseBlock, self).__init__() + self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1) + self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1) + self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1) + self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1) + self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1) + + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + # initialization + default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1) + + def forward(self, x): + x1 = self.lrelu(self.conv1(x)) + x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) + x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) + x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) + x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) + # Emperically, we use 0.2 to scale the residual for better performance + return x5 * 0.2 + x + + +class RRDB(nn.Module): + """Residual in Residual Dense Block. + + Used in RRDB-Net in ESRGAN. + + Args: + num_feat (int): Channel number of intermediate features. + num_grow_ch (int): Channels for each growth. + """ + + def __init__(self, num_feat, num_grow_ch=32): + super(RRDB, self).__init__() + self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch) + self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch) + self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch) + + def forward(self, x): + out = self.rdb1(x) + out = self.rdb2(out) + out = self.rdb3(out) + # Emperically, we use 0.2 to scale the residual for better performance + return out * 0.2 + x + + +@ARCH_REGISTRY.register() +class RRDBNet(nn.Module): + """Networks consisting of Residual in Residual Dense Block, which is used + in ESRGAN. + + ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks. + + We extend ESRGAN for scale x2 and scale x1. + Note: This is one option for scale 1, scale 2 in RRDBNet. 
+ We first employ the pixel-unshuffle (an inverse operation of pixelshuffle to reduce the spatial size + and enlarge the channel size before feeding inputs into the main ESRGAN architecture. + + Args: + num_in_ch (int): Channel number of inputs. + num_out_ch (int): Channel number of outputs. + num_feat (int): Channel number of intermediate features. + Default: 64 + num_block (int): Block number in the trunk network. Defaults: 23 + num_grow_ch (int): Channels for each growth. Default: 32. + """ + + def __init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_block=23, num_grow_ch=32): + super(RRDBNet, self).__init__() + self.scale = scale + if scale == 2: + num_in_ch = num_in_ch * 4 + elif scale == 1: + num_in_ch = num_in_ch * 16 + self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1) + self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch) + self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + # upsample + self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + def forward(self, x): + if self.scale == 2: + feat = pixel_unshuffle(x, scale=2) + elif self.scale == 1: + feat = pixel_unshuffle(x, scale=4) + else: + feat = x + feat = self.conv_first(feat) + body_feat = self.conv_body(self.body(feat)) + feat = feat + body_feat + # upsample + feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode="nearest"))) + feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode="nearest"))) + out = self.conv_last(self.lrelu(self.conv_hr(feat))) + return out diff --git a/hordelib/nodes/facerestore_cf/basicsr/archs/vgg_arch.py b/hordelib/nodes/facerestore_cf/basicsr/archs/vgg_arch.py new file mode 100644 index 00000000..fc4e6e7f --- /dev/null +++ b/hordelib/nodes/facerestore_cf/basicsr/archs/vgg_arch.py @@ -0,0 +1,264 @@ +import os +from collections import OrderedDict + +import torch +from torch import nn as nn +from torchvision.models import vgg as vgg + +from hordelib.nodes.facerestore_cf.basicsr.utils.registry import ARCH_REGISTRY + +VGG_PRETRAIN_PATH = "experiments/pretrained_models/vgg19-dcbb9e9d.pth" +NAMES = { + "vgg11": [ + "conv1_1", + "relu1_1", + "pool1", + "conv2_1", + "relu2_1", + "pool2", + "conv3_1", + "relu3_1", + "conv3_2", + "relu3_2", + "pool3", + "conv4_1", + "relu4_1", + "conv4_2", + "relu4_2", + "pool4", + "conv5_1", + "relu5_1", + "conv5_2", + "relu5_2", + "pool5", + ], + "vgg13": [ + "conv1_1", + "relu1_1", + "conv1_2", + "relu1_2", + "pool1", + "conv2_1", + "relu2_1", + "conv2_2", + "relu2_2", + "pool2", + "conv3_1", + "relu3_1", + "conv3_2", + "relu3_2", + "pool3", + "conv4_1", + "relu4_1", + "conv4_2", + "relu4_2", + "pool4", + "conv5_1", + "relu5_1", + "conv5_2", + "relu5_2", + "pool5", + ], + "vgg16": [ + "conv1_1", + "relu1_1", + "conv1_2", + "relu1_2", + "pool1", + "conv2_1", + "relu2_1", + "conv2_2", + "relu2_2", + "pool2", + "conv3_1", + "relu3_1", + "conv3_2", + "relu3_2", + "conv3_3", + "relu3_3", + "pool3", + "conv4_1", + "relu4_1", + "conv4_2", + "relu4_2", + "conv4_3", + "relu4_3", + "pool4", + "conv5_1", + "relu5_1", + "conv5_2", + "relu5_2", + "conv5_3", + "relu5_3", + "pool5", + ], + "vgg19": [ + "conv1_1", + "relu1_1", + "conv1_2", + "relu1_2", + "pool1", + "conv2_1", + "relu2_1", + "conv2_2", + "relu2_2", + "pool2", + "conv3_1", + "relu3_1", + "conv3_2", 
+ "relu3_2", + "conv3_3", + "relu3_3", + "conv3_4", + "relu3_4", + "pool3", + "conv4_1", + "relu4_1", + "conv4_2", + "relu4_2", + "conv4_3", + "relu4_3", + "conv4_4", + "relu4_4", + "pool4", + "conv5_1", + "relu5_1", + "conv5_2", + "relu5_2", + "conv5_3", + "relu5_3", + "conv5_4", + "relu5_4", + "pool5", + ], +} + + +def insert_bn(names): + """Insert bn layer after each conv. + + Args: + names (list): The list of layer names. + + Returns: + list: The list of layer names with bn layers. + """ + names_bn = [] + for name in names: + names_bn.append(name) + if "conv" in name: + position = name.replace("conv", "") + names_bn.append("bn" + position) + return names_bn + + +@ARCH_REGISTRY.register() +class VGGFeatureExtractor(nn.Module): + """VGG network for feature extraction. + + In this implementation, we allow users to choose whether use normalization + in the input feature and the type of vgg network. Note that the pretrained + path must fit the vgg type. + + Args: + layer_name_list (list[str]): Forward function returns the corresponding + features according to the layer_name_list. + Example: {'relu1_1', 'relu2_1', 'relu3_1'}. + vgg_type (str): Set the type of vgg network. Default: 'vgg19'. + use_input_norm (bool): If True, normalize the input image. Importantly, + the input feature must in the range [0, 1]. Default: True. + range_norm (bool): If True, norm images with range [-1, 1] to [0, 1]. + Default: False. + requires_grad (bool): If true, the parameters of VGG network will be + optimized. Default: False. + remove_pooling (bool): If true, the max pooling operations in VGG net + will be removed. Default: False. + pooling_stride (int): The stride of max pooling operation. Default: 2. + """ + + def __init__( + self, + layer_name_list, + vgg_type="vgg19", + use_input_norm=True, + range_norm=False, + requires_grad=False, + remove_pooling=False, + pooling_stride=2, + ): + super(VGGFeatureExtractor, self).__init__() + + self.layer_name_list = layer_name_list + self.use_input_norm = use_input_norm + self.range_norm = range_norm + + self.names = NAMES[vgg_type.replace("_bn", "")] + if "bn" in vgg_type: + self.names = insert_bn(self.names) + + # only borrow layers that will be used to avoid unused params + max_idx = 0 + for v in layer_name_list: + idx = self.names.index(v) + if idx > max_idx: + max_idx = idx + + if os.path.exists(VGG_PRETRAIN_PATH): + vgg_net = getattr(vgg, vgg_type)(pretrained=False) + state_dict = torch.load(VGG_PRETRAIN_PATH, map_location=lambda storage, loc: storage) + vgg_net.load_state_dict(state_dict) + else: + vgg_net = getattr(vgg, vgg_type)(pretrained=True) + + features = vgg_net.features[: max_idx + 1] + + modified_net = OrderedDict() + for k, v in zip(self.names, features, strict=False): + if "pool" in k: + # if remove_pooling is true, pooling operation will be removed + if remove_pooling: + continue + else: + # in some cases, we may want to change the default stride + modified_net[k] = nn.MaxPool2d(kernel_size=2, stride=pooling_stride) + else: + modified_net[k] = v + + self.vgg_net = nn.Sequential(modified_net) + + if not requires_grad: + self.vgg_net.eval() + for param in self.parameters(): + param.requires_grad = False + else: + self.vgg_net.train() + for param in self.parameters(): + param.requires_grad = True + + if self.use_input_norm: + # the mean is for image with range [0, 1] + self.register_buffer("mean", torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)) + # the std is for image with range [0, 1] + self.register_buffer("std", torch.Tensor([0.229, 0.224, 
0.225]).view(1, 3, 1, 1)) + + def forward(self, x): + """Forward function. + + Args: + x (Tensor): Input tensor with shape (n, c, h, w). + + Returns: + Tensor: Forward results. + """ + if self.range_norm: + x = (x + 1) / 2 + if self.use_input_norm: + x = (x - self.mean) / self.std + output = {} + + for key, layer in self.vgg_net._modules.items(): + x = layer(x) + if key in self.layer_name_list: + output[key] = x.clone() + + return output diff --git a/hordelib/nodes/facerestore_cf/basicsr/archs/vqgan_arch.py b/hordelib/nodes/facerestore_cf/basicsr/archs/vqgan_arch.py new file mode 100644 index 00000000..58ca0d69 --- /dev/null +++ b/hordelib/nodes/facerestore_cf/basicsr/archs/vqgan_arch.py @@ -0,0 +1,456 @@ +""" +VQGAN code, adapted from the original created by the Unleashing Transformers authors: +https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py + +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from hordelib.nodes.facerestore_cf.basicsr.utils import get_root_logger +from hordelib.nodes.facerestore_cf.basicsr.utils.registry import ARCH_REGISTRY + + +def normalize(in_channels): + return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + +@torch.jit.script +def swish(x): + return x * torch.sigmoid(x) + + +# Define VQVAE classes +class VectorQuantizer(nn.Module): + def __init__(self, codebook_size, emb_dim, beta): + super(VectorQuantizer, self).__init__() + self.codebook_size = codebook_size # number of embeddings + self.emb_dim = emb_dim # dimension of embedding + self.beta = beta # commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2 + self.embedding = nn.Embedding(self.codebook_size, self.emb_dim) + self.embedding.weight.data.uniform_(-1.0 / self.codebook_size, 1.0 / self.codebook_size) + + def forward(self, z): + # reshape z -> (batch, height, width, channel) and flatten + z = z.permute(0, 2, 3, 1).contiguous() + z_flattened = z.view(-1, self.emb_dim) + + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + d = ( + (z_flattened**2).sum(dim=1, keepdim=True) + + (self.embedding.weight**2).sum(1) + - 2 * torch.matmul(z_flattened, self.embedding.weight.t()) + ) + + mean_distance = torch.mean(d) + # find closest encodings + # min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1) + min_encoding_scores, min_encoding_indices = torch.topk(d, 1, dim=1, largest=False) + # [0-1], higher score, higher confidence + min_encoding_scores = torch.exp(-min_encoding_scores / 10) + + min_encodings = torch.zeros(min_encoding_indices.shape[0], self.codebook_size).to(z) + min_encodings.scatter_(1, min_encoding_indices, 1) + + # get quantized latent vectors + z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape) + # compute loss for embedding + loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2) + # preserve gradients + z_q = z + (z_q - z).detach() + + # perplexity + e_mean = torch.mean(min_encodings, dim=0) + perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10))) + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2).contiguous() + + return ( + z_q, + loss, + { + "perplexity": perplexity, + "min_encodings": min_encodings, + "min_encoding_indices": min_encoding_indices, + "min_encoding_scores": min_encoding_scores, + "mean_distance": mean_distance, + }, + ) + + def get_codebook_feat(self, indices, shape): + # input indices: batch*token_num -> (batch*token_num)*1 + # shape: batch, height, 
width, channel + indices = indices.view(-1, 1) + min_encodings = torch.zeros(indices.shape[0], self.codebook_size).to(indices) + min_encodings.scatter_(1, indices, 1) + # get quantized latent vectors + z_q = torch.matmul(min_encodings.float(), self.embedding.weight) + + if shape is not None: # reshape back to match original input shape + z_q = z_q.view(shape).permute(0, 3, 1, 2).contiguous() + + return z_q + + +class GumbelQuantizer(nn.Module): + def __init__(self, codebook_size, emb_dim, num_hiddens, straight_through=False, kl_weight=5e-4, temp_init=1.0): + super().__init__() + self.codebook_size = codebook_size # number of embeddings + self.emb_dim = emb_dim # dimension of embedding + self.straight_through = straight_through + self.temperature = temp_init + self.kl_weight = kl_weight + self.proj = nn.Conv2d(num_hiddens, codebook_size, 1) # projects last encoder layer to quantized logits + self.embed = nn.Embedding(codebook_size, emb_dim) + + def forward(self, z): + hard = self.straight_through if self.training else True + + logits = self.proj(z) + + soft_one_hot = F.gumbel_softmax(logits, tau=self.temperature, dim=1, hard=hard) + + z_q = torch.einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight) + + # + kl divergence to the prior loss + qy = F.softmax(logits, dim=1) + diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.codebook_size + 1e-10), dim=1).mean() + min_encoding_indices = soft_one_hot.argmax(dim=1) + + return ( + z_q, + diff, + { + "min_encoding_indices": min_encoding_indices, + }, + ) + + +class Downsample(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) + + def forward(self, x): + pad = (0, 1, 0, 1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + return x + + +class Upsample(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x): + x = F.interpolate(x, scale_factor=2.0, mode="nearest") + x = self.conv(x) + + return x + + +class ResBlock(nn.Module): + def __init__(self, in_channels, out_channels=None): + super(ResBlock, self).__init__() + self.in_channels = in_channels + self.out_channels = in_channels if out_channels is None else out_channels + self.norm1 = normalize(in_channels) + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + self.norm2 = normalize(out_channels) + self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + if self.in_channels != self.out_channels: + self.conv_out = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x_in): + x = x_in + x = self.norm1(x) + x = swish(x) + x = self.conv1(x) + x = self.norm2(x) + x = swish(x) + x = self.conv2(x) + if self.in_channels != self.out_channels: + x_in = self.conv_out(x_in) + + return x + x_in + + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = normalize(in_channels) + self.q = torch.nn.Conv2d( + in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0, + ) + self.k = torch.nn.Conv2d( + in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0, + ) + self.v = torch.nn.Conv2d( + in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0, + ) + self.proj_out = torch.nn.Conv2d( + in_channels, + in_channels, 
+ kernel_size=1, + stride=1, + padding=0, + ) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = q.reshape(b, c, h * w) + q = q.permute(0, 2, 1) + k = k.reshape(b, c, h * w) + w_ = torch.bmm(q, k) + w_ = w_ * (int(c) ** (-0.5)) + w_ = F.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b, c, h * w) + w_ = w_.permute(0, 2, 1) + h_ = torch.bmm(v, w_) + h_ = h_.reshape(b, c, h, w) + + h_ = self.proj_out(h_) + + return x + h_ + + +class Encoder(nn.Module): + def __init__(self, in_channels, nf, emb_dim, ch_mult, num_res_blocks, resolution, attn_resolutions): + super().__init__() + self.nf = nf + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.attn_resolutions = attn_resolutions + + curr_res = self.resolution + in_ch_mult = (1,) + tuple(ch_mult) + + blocks = [] + # initial convultion + blocks.append(nn.Conv2d(in_channels, nf, kernel_size=3, stride=1, padding=1)) + + # residual and downsampling blocks, with attention on smaller res (16x16) + for i in range(self.num_resolutions): + block_in_ch = nf * in_ch_mult[i] + block_out_ch = nf * ch_mult[i] + for _ in range(self.num_res_blocks): + blocks.append(ResBlock(block_in_ch, block_out_ch)) + block_in_ch = block_out_ch + if curr_res in attn_resolutions: + blocks.append(AttnBlock(block_in_ch)) + + if i != self.num_resolutions - 1: + blocks.append(Downsample(block_in_ch)) + curr_res = curr_res // 2 + + # non-local attention block + blocks.append(ResBlock(block_in_ch, block_in_ch)) + blocks.append(AttnBlock(block_in_ch)) + blocks.append(ResBlock(block_in_ch, block_in_ch)) + + # normalise and convert to latent size + blocks.append(normalize(block_in_ch)) + blocks.append(nn.Conv2d(block_in_ch, emb_dim, kernel_size=3, stride=1, padding=1)) + self.blocks = nn.ModuleList(blocks) + + def forward(self, x): + for block in self.blocks: + x = block(x) + + return x + + +class Generator(nn.Module): + def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions): + super().__init__() + self.nf = nf + self.ch_mult = ch_mult + self.num_resolutions = len(self.ch_mult) + self.num_res_blocks = res_blocks + self.resolution = img_size + self.attn_resolutions = attn_resolutions + self.in_channels = emb_dim + self.out_channels = 3 + block_in_ch = self.nf * self.ch_mult[-1] + curr_res = self.resolution // 2 ** (self.num_resolutions - 1) + + blocks = [] + # initial conv + blocks.append(nn.Conv2d(self.in_channels, block_in_ch, kernel_size=3, stride=1, padding=1)) + + # non-local attention block + blocks.append(ResBlock(block_in_ch, block_in_ch)) + blocks.append(AttnBlock(block_in_ch)) + blocks.append(ResBlock(block_in_ch, block_in_ch)) + + for i in reversed(range(self.num_resolutions)): + block_out_ch = self.nf * self.ch_mult[i] + + for _ in range(self.num_res_blocks): + blocks.append(ResBlock(block_in_ch, block_out_ch)) + block_in_ch = block_out_ch + + if curr_res in self.attn_resolutions: + blocks.append(AttnBlock(block_in_ch)) + + if i != 0: + blocks.append(Upsample(block_in_ch)) + curr_res = curr_res * 2 + + blocks.append(normalize(block_in_ch)) + blocks.append(nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1)) + + self.blocks = nn.ModuleList(blocks) + + def forward(self, x): + for block in self.blocks: + x = block(x) + + return x + + +@ARCH_REGISTRY.register() +class VQAutoEncoder(nn.Module): + def __init__( + self, + img_size, + nf, + ch_mult, + 
quantizer="nearest", + res_blocks=2, + attn_resolutions=[16], + codebook_size=1024, + emb_dim=256, + beta=0.25, + gumbel_straight_through=False, + gumbel_kl_weight=1e-8, + model_path=None, + ): + super().__init__() + logger = get_root_logger() + self.in_channels = 3 + self.nf = nf + self.n_blocks = res_blocks + self.codebook_size = codebook_size + self.embed_dim = emb_dim + self.ch_mult = ch_mult + self.resolution = img_size + self.attn_resolutions = attn_resolutions + self.quantizer_type = quantizer + self.encoder = Encoder( + self.in_channels, + self.nf, + self.embed_dim, + self.ch_mult, + self.n_blocks, + self.resolution, + self.attn_resolutions, + ) + if self.quantizer_type == "nearest": + self.beta = beta # 0.25 + self.quantize = VectorQuantizer(self.codebook_size, self.embed_dim, self.beta) + elif self.quantizer_type == "gumbel": + self.gumbel_num_hiddens = emb_dim + self.straight_through = gumbel_straight_through + self.kl_weight = gumbel_kl_weight + self.quantize = GumbelQuantizer( + self.codebook_size, + self.embed_dim, + self.gumbel_num_hiddens, + self.straight_through, + self.kl_weight, + ) + self.generator = Generator( + self.nf, + self.embed_dim, + self.ch_mult, + self.n_blocks, + self.resolution, + self.attn_resolutions, + ) + + if model_path is not None: + chkpt = torch.load(model_path, map_location="cpu") + if "params_ema" in chkpt: + self.load_state_dict(torch.load(model_path, map_location="cpu")["params_ema"]) + logger.info(f"vqgan is loaded from: {model_path} [params_ema]") + elif "params" in chkpt: + self.load_state_dict(torch.load(model_path, map_location="cpu")["params"]) + logger.info(f"vqgan is loaded from: {model_path} [params]") + else: + raise ValueError("Wrong params!") + + def forward(self, x): + x = self.encoder(x) + quant, codebook_loss, quant_stats = self.quantize(x) + x = self.generator(quant) + return x, codebook_loss, quant_stats + + +# patch based discriminator +@ARCH_REGISTRY.register() +class VQGANDiscriminator(nn.Module): + def __init__(self, nc=3, ndf=64, n_layers=4, model_path=None): + super().__init__() + + layers = [nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(0.2, True)] + ndf_mult = 1 + ndf_mult_prev = 1 + for n in range(1, n_layers): # gradually increase the number of filters + ndf_mult_prev = ndf_mult + ndf_mult = min(2**n, 8) + layers += [ + nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=2, padding=1, bias=False), + nn.BatchNorm2d(ndf * ndf_mult), + nn.LeakyReLU(0.2, True), + ] + + ndf_mult_prev = ndf_mult + ndf_mult = min(2**n_layers, 8) + + layers += [ + nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=1, padding=1, bias=False), + nn.BatchNorm2d(ndf * ndf_mult), + nn.LeakyReLU(0.2, True), + ] + + layers += [nn.Conv2d(ndf * ndf_mult, 1, kernel_size=4, stride=1, padding=1)] # output 1 channel prediction map + self.main = nn.Sequential(*layers) + + if model_path is not None: + chkpt = torch.load(model_path, map_location="cpu") + if "params_d" in chkpt: + self.load_state_dict(torch.load(model_path, map_location="cpu")["params_d"]) + elif "params" in chkpt: + self.load_state_dict(torch.load(model_path, map_location="cpu")["params"]) + else: + raise ValueError("Wrong params!") + + def forward(self, x): + return self.main(x) diff --git a/hordelib/nodes/facerestore_cf/basicsr/data/__init__.py b/hordelib/nodes/facerestore_cf/basicsr/data/__init__.py new file mode 100644 index 00000000..2dc9effe --- /dev/null +++ b/hordelib/nodes/facerestore_cf/basicsr/data/__init__.py @@ -0,0 +1,103 
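VQAutoEncoder and VQGANDiscriminator above share one checkpoint convention: the weights sit under a "params_ema", "params" or "params_d" key of the saved dict. A hypothetical helper that mirrors that convention (the function name and key order are illustrative, not part of this code):

    import torch

    def load_basicsr_checkpoint(net, ckpt_path, keys=("params_ema", "params", "params_d")):
        """Load a state dict stored under one of the usual basicsr-style keys."""
        ckpt = torch.load(ckpt_path, map_location="cpu")
        for key in keys:
            if key in ckpt:
                net.load_state_dict(ckpt[key])
                return key
        raise ValueError(f"None of {keys} found in {ckpt_path}")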
@@ +import importlib +import random +from copy import deepcopy +from functools import partial +from os import path as osp + +import numpy as np +import torch +import torch.utils.data + +from hordelib.nodes.facerestore_cf.basicsr.data.prefetch_dataloader import PrefetchDataLoader +from hordelib.nodes.facerestore_cf.basicsr.utils import get_root_logger, scandir +from hordelib.nodes.facerestore_cf.basicsr.utils.dist_util import get_dist_info +from hordelib.nodes.facerestore_cf.basicsr.utils.registry import DATASET_REGISTRY + +__all__ = ["build_dataset", "build_dataloader"] + +# automatically scan and import dataset modules for registry +# scan all the files under the data folder with '_dataset' in file names +data_folder = osp.dirname(osp.abspath(__file__)) +dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith("_dataset.py")] +# import all the dataset modules +_dataset_modules = [importlib.import_module(f"basicsr.data.{file_name}") for file_name in dataset_filenames] + + +def build_dataset(dataset_opt): + """Build dataset from options. + + Args: + dataset_opt (dict): Configuration for dataset. It must constain: + name (str): Dataset name. + type (str): Dataset type. + """ + dataset_opt = deepcopy(dataset_opt) + dataset = DATASET_REGISTRY.get(dataset_opt["type"])(dataset_opt) + logger = get_root_logger() + logger.info(f'Dataset [{dataset.__class__.__name__}] - {dataset_opt["name"]} ' "is built.") + return dataset + + +def build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=None): + """Build dataloader. + + Args: + dataset (torch.utils.data.Dataset): Dataset. + dataset_opt (dict): Dataset options. It contains the following keys: + phase (str): 'train' or 'val'. + num_worker_per_gpu (int): Number of workers for each GPU. + batch_size_per_gpu (int): Training batch size for each GPU. + num_gpu (int): Number of GPUs. Used only in the train phase. + Default: 1. + dist (bool): Whether in distributed training. Used only in the train + phase. Default: False. + sampler (torch.utils.data.sampler): Data sampler. Default: None. + seed (int | None): Seed. Default: None + """ + phase = dataset_opt["phase"] + rank, _ = get_dist_info() + if phase == "train": + if dist: # distributed training + batch_size = dataset_opt["batch_size_per_gpu"] + num_workers = dataset_opt["num_worker_per_gpu"] + else: # non-distributed training + multiplier = 1 if num_gpu == 0 else num_gpu + batch_size = dataset_opt["batch_size_per_gpu"] * multiplier + num_workers = dataset_opt["num_worker_per_gpu"] * multiplier + dataloader_args = dict( + dataset=dataset, + batch_size=batch_size, + shuffle=False, + num_workers=num_workers, + sampler=sampler, + drop_last=True, + ) + if sampler is None: + dataloader_args["shuffle"] = True + dataloader_args["worker_init_fn"] = ( + partial(worker_init_fn, num_workers=num_workers, rank=rank, seed=seed) if seed is not None else None + ) + elif phase in ["val", "test"]: # validation + dataloader_args = dict(dataset=dataset, batch_size=1, shuffle=False, num_workers=0) + else: + raise ValueError(f"Wrong dataset phase: {phase}. 
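An illustrative dataset_opt for the train phase of build_dataloader above; the key names come from the docstrings, while the concrete values and the dataset type are made up for the example:

    dataset_opt = {
        "phase": "train",
        "name": "toy_pairs",           # dataset name, used only for logging
        "type": "SomePairedDataset",   # hypothetical; must be registered in DATASET_REGISTRY
        "batch_size_per_gpu": 4,
        "num_worker_per_gpu": 2,
    }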
" "Supported ones are 'train', 'val' and 'test'.") + + dataloader_args["pin_memory"] = dataset_opt.get("pin_memory", False) + + prefetch_mode = dataset_opt.get("prefetch_mode") + if prefetch_mode == "cpu": # CPUPrefetcher + num_prefetch_queue = dataset_opt.get("num_prefetch_queue", 1) + logger = get_root_logger() + logger.info(f"Use {prefetch_mode} prefetch dataloader: " f"num_prefetch_queue = {num_prefetch_queue}") + return PrefetchDataLoader(num_prefetch_queue=num_prefetch_queue, **dataloader_args) + else: + # prefetch_mode=None: Normal dataloader + # prefetch_mode='cuda': dataloader for CUDAPrefetcher + return torch.utils.data.DataLoader(**dataloader_args) + + +def worker_init_fn(worker_id, num_workers, rank, seed): + # Set the worker seed to num_workers * rank + worker_id + seed + worker_seed = num_workers * rank + worker_id + seed + np.random.seed(worker_seed) + random.seed(worker_seed) diff --git a/hordelib/nodes/facerestore_cf/basicsr/data/data_sampler.py b/hordelib/nodes/facerestore_cf/basicsr/data/data_sampler.py new file mode 100644 index 00000000..e4f2ccf8 --- /dev/null +++ b/hordelib/nodes/facerestore_cf/basicsr/data/data_sampler.py @@ -0,0 +1,49 @@ +import math + +import torch +from torch.utils.data.sampler import Sampler + + +class EnlargedSampler(Sampler): + """Sampler that restricts data loading to a subset of the dataset. + + Modified from torch.utils.data.distributed.DistributedSampler + Support enlarging the dataset for iteration-based training, for saving + time when restart the dataloader after each epoch + + Args: + dataset (torch.utils.data.Dataset): Dataset used for sampling. + num_replicas (int | None): Number of processes participating in + the training. It is usually the world_size. + rank (int | None): Rank of the current process within num_replicas. + ratio (int): Enlarging ratio. Default: 1. + """ + + def __init__(self, dataset, num_replicas, rank, ratio=1): + self.dataset = dataset + self.num_replicas = num_replicas + self.rank = rank + self.epoch = 0 + self.num_samples = math.ceil(len(self.dataset) * ratio / self.num_replicas) + self.total_size = self.num_samples * self.num_replicas + + def __iter__(self): + # deterministically shuffle based on epoch + g = torch.Generator() + g.manual_seed(self.epoch) + indices = torch.randperm(self.total_size, generator=g).tolist() + + dataset_size = len(self.dataset) + indices = [v % dataset_size for v in indices] + + # subsample + indices = indices[self.rank : self.total_size : self.num_replicas] + assert len(indices) == self.num_samples + + return iter(indices) + + def __len__(self): + return self.num_samples + + def set_epoch(self, epoch): + self.epoch = epoch diff --git a/hordelib/nodes/facerestore_cf/basicsr/data/data_util.py b/hordelib/nodes/facerestore_cf/basicsr/data/data_util.py new file mode 100644 index 00000000..b2481c00 --- /dev/null +++ b/hordelib/nodes/facerestore_cf/basicsr/data/data_util.py @@ -0,0 +1,313 @@ +from os import path as osp + +import cv2 +import numpy as np +import torch +from torch.nn import functional as F + +from hordelib.nodes.facerestore_cf.basicsr.data.transforms import mod_crop +from hordelib.nodes.facerestore_cf.basicsr.utils import img2tensor, scandir + + +def read_img_seq(path, require_mod_crop=False, scale=1): + """Read a sequence of images from a given folder path. + + Args: + path (list[str] | str): List of image paths or image folder path. + require_mod_crop (bool): Require mod crop for each image. + Default: False. + scale (int): Scale factor for mod_crop. Default: 1. 
+ + Returns: + Tensor: size (t, c, h, w), RGB, [0, 1]. + """ + if isinstance(path, list): + img_paths = path + else: + img_paths = sorted(list(scandir(path, full_path=True))) + imgs = [cv2.imread(v).astype(np.float32) / 255.0 for v in img_paths] + if require_mod_crop: + imgs = [mod_crop(img, scale) for img in imgs] + imgs = img2tensor(imgs, bgr2rgb=True, float32=True) + imgs = torch.stack(imgs, dim=0) + return imgs + + +def generate_frame_indices(crt_idx, max_frame_num, num_frames, padding="reflection"): + """Generate an index list for reading `num_frames` frames from a sequence + of images. + + Args: + crt_idx (int): Current center index. + max_frame_num (int): Max number of the sequence of images (from 1). + num_frames (int): Reading num_frames frames. + padding (str): Padding mode, one of + 'replicate' | 'reflection' | 'reflection_circle' | 'circle' + Examples: current_idx = 0, num_frames = 5 + The generated frame indices under different padding mode: + replicate: [0, 0, 0, 1, 2] + reflection: [2, 1, 0, 1, 2] + reflection_circle: [4, 3, 0, 1, 2] + circle: [3, 4, 0, 1, 2] + + Returns: + list[int]: A list of indices. + """ + assert num_frames % 2 == 1, "num_frames should be an odd number." + assert padding in ("replicate", "reflection", "reflection_circle", "circle"), f"Wrong padding mode: {padding}." + + max_frame_num = max_frame_num - 1 # start from 0 + num_pad = num_frames // 2 + + indices = [] + for i in range(crt_idx - num_pad, crt_idx + num_pad + 1): + if i < 0: + if padding == "replicate": + pad_idx = 0 + elif padding == "reflection": + pad_idx = -i + elif padding == "reflection_circle": + pad_idx = crt_idx + num_pad - i + else: + pad_idx = num_frames + i + elif i > max_frame_num: + if padding == "replicate": + pad_idx = max_frame_num + elif padding == "reflection": + pad_idx = max_frame_num * 2 - i + elif padding == "reflection_circle": + pad_idx = (crt_idx - num_pad) - (i - max_frame_num) + else: + pad_idx = i - num_frames + else: + pad_idx = i + indices.append(pad_idx) + return indices + + +def paired_paths_from_lmdb(folders, keys): + """Generate paired paths from lmdb files. + + Contents of lmdb. Taking the `lq.lmdb` for example, the file structure is: + + lq.lmdb + ├── data.mdb + ├── lock.mdb + ├── meta_info.txt + + The data.mdb and lock.mdb are standard lmdb files and you can refer to + https://lmdb.readthedocs.io/en/release/ for more details. + + The meta_info.txt is a specified txt file to record the meta information + of our datasets. It will be automatically created when preparing + datasets by our provided dataset tools. + Each line in the txt file records + 1)image name (with extension), + 2)image shape, + 3)compression level, separated by a white space. + Example: `baboon.png (120,125,3) 1` + + We use the image name without extension as the lmdb key. + Note that we use the same key for the corresponding lq and gt images. + + Args: + folders (list[str]): A list of folder path. The order of list should + be [input_folder, gt_folder]. + keys (list[str]): A list of keys identifying folders. The order should + be in consistent with folders, e.g., ['lq', 'gt']. + Note that this key is different from lmdb keys. + + Returns: + list[str]: Returned path list. + """ + assert len(folders) == 2, ( + "The len of folders should be 2 with [input_folder, gt_folder]. " f"But got {len(folders)}" + ) + assert len(keys) == 2, "The len of keys should be 2 with [input_key, gt_key]. 
" f"But got {len(keys)}" + input_folder, gt_folder = folders + input_key, gt_key = keys + + if not (input_folder.endswith(".lmdb") and gt_folder.endswith(".lmdb")): + raise ValueError( + f"{input_key} folder and {gt_key} folder should both in lmdb " + f"formats. But received {input_key}: {input_folder}; " + f"{gt_key}: {gt_folder}", + ) + # ensure that the two meta_info files are the same + with open(osp.join(input_folder, "meta_info.txt")) as fin: + input_lmdb_keys = [line.split(".")[0] for line in fin] + with open(osp.join(gt_folder, "meta_info.txt")) as fin: + gt_lmdb_keys = [line.split(".")[0] for line in fin] + if set(input_lmdb_keys) != set(gt_lmdb_keys): + raise ValueError(f"Keys in {input_key}_folder and {gt_key}_folder are different.") + else: + paths = [] + for lmdb_key in sorted(input_lmdb_keys): + paths.append(dict([(f"{input_key}_path", lmdb_key), (f"{gt_key}_path", lmdb_key)])) + return paths + + +def paired_paths_from_meta_info_file(folders, keys, meta_info_file, filename_tmpl): + """Generate paired paths from an meta information file. + + Each line in the meta information file contains the image names and + image shape (usually for gt), separated by a white space. + + Example of an meta information file: + ``` + 0001_s001.png (480,480,3) + 0001_s002.png (480,480,3) + ``` + + Args: + folders (list[str]): A list of folder path. The order of list should + be [input_folder, gt_folder]. + keys (list[str]): A list of keys identifying folders. The order should + be in consistent with folders, e.g., ['lq', 'gt']. + meta_info_file (str): Path to the meta information file. + filename_tmpl (str): Template for each filename. Note that the + template excludes the file extension. Usually the filename_tmpl is + for files in the input folder. + + Returns: + list[str]: Returned path list. + """ + assert len(folders) == 2, ( + "The len of folders should be 2 with [input_folder, gt_folder]. " f"But got {len(folders)}" + ) + assert len(keys) == 2, "The len of keys should be 2 with [input_key, gt_key]. " f"But got {len(keys)}" + input_folder, gt_folder = folders + input_key, gt_key = keys + + with open(meta_info_file) as fin: + gt_names = [line.split(" ")[0] for line in fin] + + paths = [] + for gt_name in gt_names: + basename, ext = osp.splitext(osp.basename(gt_name)) + input_name = f"{filename_tmpl.format(basename)}{ext}" + input_path = osp.join(input_folder, input_name) + gt_path = osp.join(gt_folder, gt_name) + paths.append(dict([(f"{input_key}_path", input_path), (f"{gt_key}_path", gt_path)])) + return paths + + +def paired_paths_from_folder(folders, keys, filename_tmpl): + """Generate paired paths from folders. + + Args: + folders (list[str]): A list of folder path. The order of list should + be [input_folder, gt_folder]. + keys (list[str]): A list of keys identifying folders. The order should + be in consistent with folders, e.g., ['lq', 'gt']. + filename_tmpl (str): Template for each filename. Note that the + template excludes the file extension. Usually the filename_tmpl is + for files in the input folder. + + Returns: + list[str]: Returned path list. + """ + assert len(folders) == 2, ( + "The len of folders should be 2 with [input_folder, gt_folder]. " f"But got {len(folders)}" + ) + assert len(keys) == 2, "The len of keys should be 2 with [input_key, gt_key]. 
" f"But got {len(keys)}" + input_folder, gt_folder = folders + input_key, gt_key = keys + + input_paths = list(scandir(input_folder)) + gt_paths = list(scandir(gt_folder)) + assert len(input_paths) == len(gt_paths), ( + f"{input_key} and {gt_key} datasets have different number of images: " f"{len(input_paths)}, {len(gt_paths)}." + ) + paths = [] + for gt_path in gt_paths: + basename, ext = osp.splitext(osp.basename(gt_path)) + input_name = f"{filename_tmpl.format(basename)}{ext}" + input_path = osp.join(input_folder, input_name) + assert input_name in input_paths, f"{input_name} is not in " f"{input_key}_paths." + gt_path = osp.join(gt_folder, gt_path) + paths.append(dict([(f"{input_key}_path", input_path), (f"{gt_key}_path", gt_path)])) + return paths + + +def paths_from_folder(folder): + """Generate paths from folder. + + Args: + folder (str): Folder path. + + Returns: + list[str]: Returned path list. + """ + + paths = list(scandir(folder)) + paths = [osp.join(folder, path) for path in paths] + return paths + + +def paths_from_lmdb(folder): + """Generate paths from lmdb. + + Args: + folder (str): Folder path. + + Returns: + list[str]: Returned path list. + """ + if not folder.endswith(".lmdb"): + raise ValueError(f"Folder {folder}folder should in lmdb format.") + with open(osp.join(folder, "meta_info.txt")) as fin: + paths = [line.split(".")[0] for line in fin] + return paths + + +def generate_gaussian_kernel(kernel_size=13, sigma=1.6): + """Generate Gaussian kernel used in `duf_downsample`. + + Args: + kernel_size (int): Kernel size. Default: 13. + sigma (float): Sigma of the Gaussian kernel. Default: 1.6. + + Returns: + np.array: The Gaussian kernel. + """ + from scipy.ndimage import filters as filters + + kernel = np.zeros((kernel_size, kernel_size)) + # set element at the middle to one, a dirac delta + kernel[kernel_size // 2, kernel_size // 2] = 1 + # gaussian-smooth the dirac, resulting in a gaussian filter + return filters.gaussian_filter(kernel, sigma) + + +def duf_downsample(x, kernel_size=13, scale=4): + """Downsamping with Gaussian kernel used in the DUF official code. + + Args: + x (Tensor): Frames to be downsampled, with shape (b, t, c, h, w). + kernel_size (int): Kernel size. Default: 13. + scale (int): Downsampling factor. Supported scale: (2, 3, 4). + Default: 4. + + Returns: + Tensor: DUF downsampled frames. + """ + assert scale in (2, 3, 4), f"Only support scale (2, 3, 4), but got {scale}." + + squeeze_flag = False + if x.ndim == 4: + squeeze_flag = True + x = x.unsqueeze(0) + b, t, c, h, w = x.size() + x = x.view(-1, 1, h, w) + pad_w, pad_h = kernel_size // 2 + scale * 2, kernel_size // 2 + scale * 2 + x = F.pad(x, (pad_w, pad_w, pad_h, pad_h), "reflect") + + gaussian_filter = generate_gaussian_kernel(kernel_size, 0.4 * scale) + gaussian_filter = torch.from_numpy(gaussian_filter).type_as(x).unsqueeze(0).unsqueeze(0) + x = F.conv2d(x, gaussian_filter, stride=scale) + x = x[:, :, 2:-2, 2:-2] + x = x.view(b, t, c, x.size(2), x.size(3)) + if squeeze_flag: + x = x.squeeze(0) + return x diff --git a/hordelib/nodes/facerestore_cf/basicsr/data/prefetch_dataloader.py b/hordelib/nodes/facerestore_cf/basicsr/data/prefetch_dataloader.py new file mode 100644 index 00000000..dd84bb15 --- /dev/null +++ b/hordelib/nodes/facerestore_cf/basicsr/data/prefetch_dataloader.py @@ -0,0 +1,126 @@ +import queue as Queue +import threading + +import torch +from torch.utils.data import DataLoader + + +class PrefetchGenerator(threading.Thread): + """A general prefetch generator. 
+ + Ref: + https://stackoverflow.com/questions/7323664/python-generator-pre-fetch + + Args: + generator: Python generator. + num_prefetch_queue (int): Number of prefetch queue. + """ + + def __init__(self, generator, num_prefetch_queue): + threading.Thread.__init__(self) + self.queue = Queue.Queue(num_prefetch_queue) + self.generator = generator + self.daemon = True + self.start() + + def run(self): + for item in self.generator: + self.queue.put(item) + self.queue.put(None) + + def __next__(self): + next_item = self.queue.get() + if next_item is None: + raise StopIteration + return next_item + + def __iter__(self): + return self + + +class PrefetchDataLoader(DataLoader): + """Prefetch version of dataloader. + + Ref: + https://github.com/IgorSusmelj/pytorch-styleguide/issues/5# + + TODO: + Need to test on single gpu and ddp (multi-gpu). There is a known issue in + ddp. + + Args: + num_prefetch_queue (int): Number of prefetch queue. + kwargs (dict): Other arguments for dataloader. + """ + + def __init__(self, num_prefetch_queue, **kwargs): + self.num_prefetch_queue = num_prefetch_queue + super(PrefetchDataLoader, self).__init__(**kwargs) + + def __iter__(self): + return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue) + + +class CPUPrefetcher: + """CPU prefetcher. + + Args: + loader: Dataloader. + """ + + def __init__(self, loader): + self.ori_loader = loader + self.loader = iter(loader) + + def next(self): + try: + return next(self.loader) + except StopIteration: + return None + + def reset(self): + self.loader = iter(self.ori_loader) + + +class CUDAPrefetcher: + """CUDA prefetcher. + + Ref: + https://github.com/NVIDIA/apex/issues/304# + + It may consums more GPU memory. + + Args: + loader: Dataloader. + opt (dict): Options. + """ + + def __init__(self, loader, opt): + self.ori_loader = loader + self.loader = iter(loader) + self.opt = opt + self.stream = torch.cuda.Stream() + self.device = torch.device("cuda" if opt["num_gpu"] != 0 else "cpu") + self.preload() + + def preload(self): + try: + self.batch = next(self.loader) # self.batch is a dict + except StopIteration: + self.batch = None + return + # put tensors to gpu + with torch.cuda.stream(self.stream): + for k, v in self.batch.items(): + if torch.is_tensor(v): + self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True) + + def next(self): + torch.cuda.current_stream().wait_stream(self.stream) + batch = self.batch + self.preload() + return batch + + def reset(self): + self.loader = iter(self.ori_loader) + self.preload() diff --git a/hordelib/nodes/facerestore_cf/basicsr/data/transforms.py b/hordelib/nodes/facerestore_cf/basicsr/data/transforms.py new file mode 100644 index 00000000..ac562bff --- /dev/null +++ b/hordelib/nodes/facerestore_cf/basicsr/data/transforms.py @@ -0,0 +1,170 @@ +import random + +import cv2 + + +def mod_crop(img, scale): + """Mod crop images, used during testing. + + Args: + img (ndarray): Input image. + scale (int): Scale factor. + + Returns: + ndarray: Result image. + """ + img = img.copy() + if img.ndim in (2, 3): + h, w = img.shape[0], img.shape[1] + h_remainder, w_remainder = h % scale, w % scale + img = img[: h - h_remainder, : w - w_remainder, ...] + else: + raise ValueError(f"Wrong img ndim: {img.ndim}.") + return img + + +def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path): + """Paired random crop. + + It crops lists of lq and gt images with corresponding locations. + + Args: + img_gts (list[ndarray] | ndarray): GT images. 
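mod_crop above simply trims an image so that its height and width are divisible by the scale; restated on a toy array:

    import numpy as np

    img = np.zeros((103, 130, 3), dtype=np.float32)
    scale = 4
    h, w = img.shape[:2]
    cropped = img[: h - h % scale, : w - w % scale, ...]   # shape (100, 128, 3)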
Note that all images + should have the same shape. If the input is an ndarray, it will + be transformed to a list containing itself. + img_lqs (list[ndarray] | ndarray): LQ images. Note that all images + should have the same shape. If the input is an ndarray, it will + be transformed to a list containing itself. + gt_patch_size (int): GT patch size. + scale (int): Scale factor. + gt_path (str): Path to ground-truth. + + Returns: + list[ndarray] | ndarray: GT images and LQ images. If returned results + only have one element, just return ndarray. + """ + + if not isinstance(img_gts, list): + img_gts = [img_gts] + if not isinstance(img_lqs, list): + img_lqs = [img_lqs] + + h_lq, w_lq, _ = img_lqs[0].shape + h_gt, w_gt, _ = img_gts[0].shape + lq_patch_size = gt_patch_size // scale + + if h_gt != h_lq * scale or w_gt != w_lq * scale: + raise ValueError( + f"Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x ", + f"multiplication of LQ ({h_lq}, {w_lq}).", + ) + if h_lq < lq_patch_size or w_lq < lq_patch_size: + raise ValueError( + f"LQ ({h_lq}, {w_lq}) is smaller than patch size " + f"({lq_patch_size}, {lq_patch_size}). " + f"Please remove {gt_path}.", + ) + + # randomly choose top and left coordinates for lq patch + top = random.randint(0, h_lq - lq_patch_size) + left = random.randint(0, w_lq - lq_patch_size) + + # crop lq patch + img_lqs = [v[top : top + lq_patch_size, left : left + lq_patch_size, ...] for v in img_lqs] + + # crop corresponding gt patch + top_gt, left_gt = int(top * scale), int(left * scale) + img_gts = [v[top_gt : top_gt + gt_patch_size, left_gt : left_gt + gt_patch_size, ...] for v in img_gts] + if len(img_gts) == 1: + img_gts = img_gts[0] + if len(img_lqs) == 1: + img_lqs = img_lqs[0] + return img_gts, img_lqs + + +def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False): + """Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees). + + We use vertical flip and transpose for rotation implementation. + All the images in the list use the same augmentation. + + Args: + imgs (list[ndarray] | ndarray): Images to be augmented. If the input + is an ndarray, it will be transformed to a list. + hflip (bool): Horizontal flip. Default: True. + rotation (bool): Ratotation. Default: True. + flows (list[ndarray]: Flows to be augmented. If the input is an + ndarray, it will be transformed to a list. + Dimension is (h, w, 2). Default: None. + return_status (bool): Return the status of flip and rotation. + Default: False. + + Returns: + list[ndarray] | ndarray: Augmented images and flows. If returned + results only have one element, just return ndarray. 
+ + """ + hflip = hflip and random.random() < 0.5 + vflip = rotation and random.random() < 0.5 + rot90 = rotation and random.random() < 0.5 + + def _augment(img): + if hflip: # horizontal + cv2.flip(img, 1, img) + if vflip: # vertical + cv2.flip(img, 0, img) + if rot90: + img = img.transpose(1, 0, 2) + return img + + def _augment_flow(flow): + if hflip: # horizontal + cv2.flip(flow, 1, flow) + flow[:, :, 0] *= -1 + if vflip: # vertical + cv2.flip(flow, 0, flow) + flow[:, :, 1] *= -1 + if rot90: + flow = flow.transpose(1, 0, 2) + flow = flow[:, :, [1, 0]] + return flow + + if not isinstance(imgs, list): + imgs = [imgs] + imgs = [_augment(img) for img in imgs] + if len(imgs) == 1: + imgs = imgs[0] + + if flows is not None: + if not isinstance(flows, list): + flows = [flows] + flows = [_augment_flow(flow) for flow in flows] + if len(flows) == 1: + flows = flows[0] + return imgs, flows + else: + if return_status: + return imgs, (hflip, vflip, rot90) + else: + return imgs + + +def img_rotate(img, angle, center=None, scale=1.0): + """Rotate image. + + Args: + img (ndarray): Image to be rotated. + angle (float): Rotation angle in degrees. Positive values mean + counter-clockwise rotation. + center (tuple[int]): Rotation center. If the center is None, + initialize it as the center of the image. Default: None. + scale (float): Isotropic scale factor. Default: 1.0. + """ + (h, w) = img.shape[:2] + + if center is None: + center = (w // 2, h // 2) + + matrix = cv2.getRotationMatrix2D(center, angle, scale) + rotated_img = cv2.warpAffine(img, matrix, (w, h)) + return rotated_img diff --git a/hordelib/nodes/facerestore_cf/basicsr/losses/__init__.py b/hordelib/nodes/facerestore_cf/basicsr/losses/__init__.py new file mode 100644 index 00000000..fb45a76e --- /dev/null +++ b/hordelib/nodes/facerestore_cf/basicsr/losses/__init__.py @@ -0,0 +1,43 @@ +from copy import deepcopy + +from hordelib.nodes.facerestore_cf.basicsr.utils import get_root_logger +from hordelib.nodes.facerestore_cf.basicsr.utils.registry import LOSS_REGISTRY + +from .losses import ( + CharbonnierLoss, + GANLoss, + L1Loss, + MSELoss, + PerceptualLoss, + WeightedTVLoss, + g_path_regularize, + gradient_penalty_loss, + r1_penalty, +) + +__all__ = [ + "L1Loss", + "MSELoss", + "CharbonnierLoss", + "WeightedTVLoss", + "PerceptualLoss", + "GANLoss", + "gradient_penalty_loss", + "r1_penalty", + "g_path_regularize", +] + + +def build_loss(opt): + """Build loss from options. + + Args: + opt (dict): Configuration. It must constain: + type (str): Model type. + """ + opt = deepcopy(opt) + loss_type = opt.pop("type") + loss = LOSS_REGISTRY.get(loss_type)(**opt) + logger = get_root_logger() + logger.info(f"Loss [{loss.__class__.__name__}] is created.") + return loss diff --git a/hordelib/nodes/facerestore_cf/basicsr/losses/loss_util.py b/hordelib/nodes/facerestore_cf/basicsr/losses/loss_util.py new file mode 100644 index 00000000..b08bcaee --- /dev/null +++ b/hordelib/nodes/facerestore_cf/basicsr/losses/loss_util.py @@ -0,0 +1,96 @@ +import functools + +from torch.nn import functional as F + + +def reduce_loss(loss, reduction): + """Reduce loss as specified. + + Args: + loss (Tensor): Elementwise loss tensor. + reduction (str): Options are 'none', 'mean' and 'sum'. + + Returns: + Tensor: Reduced loss tensor. 
+ """ + reduction_enum = F._Reduction.get_enum(reduction) + # none: 0, elementwise_mean:1, sum: 2 + if reduction_enum == 0: + return loss + elif reduction_enum == 1: + return loss.mean() + else: + return loss.sum() + + +def weight_reduce_loss(loss, weight=None, reduction="mean"): + """Apply element-wise weight and reduce loss. + + Args: + loss (Tensor): Element-wise loss. + weight (Tensor): Element-wise weights. Default: None. + reduction (str): Same as built-in losses of PyTorch. Options are + 'none', 'mean' and 'sum'. Default: 'mean'. + + Returns: + Tensor: Loss values. + """ + # if weight is specified, apply element-wise weight + if weight is not None: + assert weight.dim() == loss.dim() + assert weight.size(1) == 1 or weight.size(1) == loss.size(1) + loss = loss * weight + + # if weight is not specified or reduction is sum, just reduce the loss + if weight is None or reduction == "sum": + loss = reduce_loss(loss, reduction) + # if reduction is mean, then compute mean over weight region + elif reduction == "mean": + if weight.size(1) > 1: + weight = weight.sum() + else: + weight = weight.sum() * loss.size(1) + loss = loss.sum() / weight + + return loss + + +def weighted_loss(loss_func): + """Create a weighted version of a given loss function. + + To use this decorator, the loss function must have the signature like + `loss_func(pred, target, **kwargs)`. The function only needs to compute + element-wise loss without any reduction. This decorator will add weight + and reduction arguments to the function. The decorated function will have + the signature like `loss_func(pred, target, weight=None, reduction='mean', + **kwargs)`. + + :Example: + + >>> import torch + >>> @weighted_loss + >>> def l1_loss(pred, target): + >>> return (pred - target).abs() + + >>> pred = torch.Tensor([0, 2, 3]) + >>> target = torch.Tensor([1, 1, 1]) + >>> weight = torch.Tensor([1, 0, 1]) + + >>> l1_loss(pred, target) + tensor(1.3333) + >>> l1_loss(pred, target, weight) + tensor(1.5000) + >>> l1_loss(pred, target, reduction='none') + tensor([1., 1., 2.]) + >>> l1_loss(pred, target, weight, reduction='sum') + tensor(3.) + """ + + @functools.wraps(loss_func) + def wrapper(pred, target, weight=None, reduction="mean", **kwargs): + # get element-wise loss + loss = loss_func(pred, target, **kwargs) + loss = weight_reduce_loss(loss, weight, reduction) + return loss + + return wrapper diff --git a/hordelib/nodes/facerestore_cf/basicsr/losses/losses.py b/hordelib/nodes/facerestore_cf/basicsr/losses/losses.py new file mode 100644 index 00000000..a6c9cd8d --- /dev/null +++ b/hordelib/nodes/facerestore_cf/basicsr/losses/losses.py @@ -0,0 +1,470 @@ +import math + +import lpips +import torch +from torch import autograd as autograd +from torch import nn as nn +from torch.nn import functional as F + +from hordelib.nodes.facerestore_cf.basicsr.archs.vgg_arch import VGGFeatureExtractor +from hordelib.nodes.facerestore_cf.basicsr.utils.registry import LOSS_REGISTRY + +from .loss_util import weighted_loss + +_reduction_modes = ["none", "mean", "sum"] + + +@weighted_loss +def l1_loss(pred, target): + return F.l1_loss(pred, target, reduction="none") + + +@weighted_loss +def mse_loss(pred, target): + return F.mse_loss(pred, target, reduction="none") + + +@weighted_loss +def charbonnier_loss(pred, target, eps=1e-12): + return torch.sqrt((pred - target) ** 2 + eps) + + +@LOSS_REGISTRY.register() +class L1Loss(nn.Module): + """L1 (mean absolute error, MAE) loss. + + Args: + loss_weight (float): Loss weight for L1 loss. Default: 1.0. 
+ reduction (str): Specifies the reduction to apply to the output. + Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'. + """ + + def __init__(self, loss_weight=1.0, reduction="mean"): + super(L1Loss, self).__init__() + if reduction not in ["none", "mean", "sum"]: + raise ValueError(f"Unsupported reduction mode: {reduction}. " f"Supported ones are: {_reduction_modes}") + + self.loss_weight = loss_weight + self.reduction = reduction + + def forward(self, pred, target, weight=None, **kwargs): + """ + Args: + pred (Tensor): of shape (N, C, H, W). Predicted tensor. + target (Tensor): of shape (N, C, H, W). Ground truth tensor. + weight (Tensor, optional): of shape (N, C, H, W). Element-wise + weights. Default: None. + """ + return self.loss_weight * l1_loss(pred, target, weight, reduction=self.reduction) + + +@LOSS_REGISTRY.register() +class MSELoss(nn.Module): + """MSE (L2) loss. + + Args: + loss_weight (float): Loss weight for MSE loss. Default: 1.0. + reduction (str): Specifies the reduction to apply to the output. + Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'. + """ + + def __init__(self, loss_weight=1.0, reduction="mean"): + super(MSELoss, self).__init__() + if reduction not in ["none", "mean", "sum"]: + raise ValueError(f"Unsupported reduction mode: {reduction}. " f"Supported ones are: {_reduction_modes}") + + self.loss_weight = loss_weight + self.reduction = reduction + + def forward(self, pred, target, weight=None, **kwargs): + """ + Args: + pred (Tensor): of shape (N, C, H, W). Predicted tensor. + target (Tensor): of shape (N, C, H, W). Ground truth tensor. + weight (Tensor, optional): of shape (N, C, H, W). Element-wise + weights. Default: None. + """ + return self.loss_weight * mse_loss(pred, target, weight, reduction=self.reduction) + + +@LOSS_REGISTRY.register() +class CharbonnierLoss(nn.Module): + """Charbonnier loss (one variant of Robust L1Loss, a differentiable + variant of L1Loss). + + Described in "Deep Laplacian Pyramid Networks for Fast and Accurate + Super-Resolution". + + Args: + loss_weight (float): Loss weight for L1 loss. Default: 1.0. + reduction (str): Specifies the reduction to apply to the output. + Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'. + eps (float): A value used to control the curvature near zero. + Default: 1e-12. + """ + + def __init__(self, loss_weight=1.0, reduction="mean", eps=1e-12): + super(CharbonnierLoss, self).__init__() + if reduction not in ["none", "mean", "sum"]: + raise ValueError(f"Unsupported reduction mode: {reduction}. " f"Supported ones are: {_reduction_modes}") + + self.loss_weight = loss_weight + self.reduction = reduction + self.eps = eps + + def forward(self, pred, target, weight=None, **kwargs): + """ + Args: + pred (Tensor): of shape (N, C, H, W). Predicted tensor. + target (Tensor): of shape (N, C, H, W). Ground truth tensor. + weight (Tensor, optional): of shape (N, C, H, W). Element-wise + weights. Default: None. + """ + return self.loss_weight * charbonnier_loss(pred, target, weight, eps=self.eps, reduction=self.reduction) + + +@LOSS_REGISTRY.register() +class WeightedTVLoss(L1Loss): + """Weighted TV loss. + + Args: + loss_weight (float): Loss weight. Default: 1.0. 
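Numerically, the Charbonnier loss above behaves like a smoothed L1: for residuals much larger than sqrt(eps) the two are nearly identical. A tiny worked example:

    import torch

    pred = torch.tensor([0.0, 2.0, 3.0])
    target = torch.tensor([1.0, 1.0, 1.0])
    eps = 1e-12
    charbonnier = torch.sqrt((pred - target) ** 2 + eps).mean()   # ~1.3333
    l1 = (pred - target).abs().mean()                             # 1.3333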
+ """ + + def __init__(self, loss_weight=1.0): + super(WeightedTVLoss, self).__init__(loss_weight=loss_weight) + + def forward(self, pred, weight=None): + y_diff = super(WeightedTVLoss, self).forward( + pred[:, :, :-1, :], + pred[:, :, 1:, :], + weight=weight[:, :, :-1, :], + ) + x_diff = super(WeightedTVLoss, self).forward( + pred[:, :, :, :-1], + pred[:, :, :, 1:], + weight=weight[:, :, :, :-1], + ) + + loss = x_diff + y_diff + + return loss + + +@LOSS_REGISTRY.register() +class PerceptualLoss(nn.Module): + """Perceptual loss with commonly used style loss. + + Args: + layer_weights (dict): The weight for each layer of vgg feature. + Here is an example: {'conv5_4': 1.}, which means the conv5_4 + feature layer (before relu5_4) will be extracted with weight + 1.0 in calculting losses. + vgg_type (str): The type of vgg network used as feature extractor. + Default: 'vgg19'. + use_input_norm (bool): If True, normalize the input image in vgg. + Default: True. + range_norm (bool): If True, norm images with range [-1, 1] to [0, 1]. + Default: False. + perceptual_weight (float): If `perceptual_weight > 0`, the perceptual + loss will be calculated and the loss will multiplied by the + weight. Default: 1.0. + style_weight (float): If `style_weight > 0`, the style loss will be + calculated and the loss will multiplied by the weight. + Default: 0. + criterion (str): Criterion used for perceptual loss. Default: 'l1'. + """ + + def __init__( + self, + layer_weights, + vgg_type="vgg19", + use_input_norm=True, + range_norm=False, + perceptual_weight=1.0, + style_weight=0.0, + criterion="l1", + ): + super(PerceptualLoss, self).__init__() + self.perceptual_weight = perceptual_weight + self.style_weight = style_weight + self.layer_weights = layer_weights + self.vgg = VGGFeatureExtractor( + layer_name_list=list(layer_weights.keys()), + vgg_type=vgg_type, + use_input_norm=use_input_norm, + range_norm=range_norm, + ) + + self.criterion_type = criterion + if self.criterion_type == "l1": + self.criterion = torch.nn.L1Loss() + elif self.criterion_type == "l2": + self.criterion = torch.nn.L2loss() + elif self.criterion_type == "mse": + self.criterion = torch.nn.MSELoss(reduction="mean") + elif self.criterion_type == "fro": + self.criterion = None + else: + raise NotImplementedError(f"{criterion} criterion has not been supported.") + + def forward(self, x, gt): + """Forward function. + + Args: + x (Tensor): Input tensor with shape (n, c, h, w). + gt (Tensor): Ground-truth tensor with shape (n, c, h, w). + + Returns: + Tensor: Forward results. 
+ """ + # extract vgg features + x_features = self.vgg(x) + gt_features = self.vgg(gt.detach()) + + # calculate perceptual loss + if self.perceptual_weight > 0: + percep_loss = 0 + for k in x_features.keys(): + if self.criterion_type == "fro": + percep_loss += torch.norm(x_features[k] - gt_features[k], p="fro") * self.layer_weights[k] + else: + percep_loss += self.criterion(x_features[k], gt_features[k]) * self.layer_weights[k] + percep_loss *= self.perceptual_weight + else: + percep_loss = None + + # calculate style loss + if self.style_weight > 0: + style_loss = 0 + for k in x_features.keys(): + if self.criterion_type == "fro": + style_loss += ( + torch.norm(self._gram_mat(x_features[k]) - self._gram_mat(gt_features[k]), p="fro") + * self.layer_weights[k] + ) + else: + style_loss += ( + self.criterion(self._gram_mat(x_features[k]), self._gram_mat(gt_features[k])) + * self.layer_weights[k] + ) + style_loss *= self.style_weight + else: + style_loss = None + + return percep_loss, style_loss + + def _gram_mat(self, x): + """Calculate Gram matrix. + + Args: + x (torch.Tensor): Tensor with shape of (n, c, h, w). + + Returns: + torch.Tensor: Gram matrix. + """ + n, c, h, w = x.size() + features = x.view(n, c, w * h) + features_t = features.transpose(1, 2) + gram = features.bmm(features_t) / (c * h * w) + return gram + + +@LOSS_REGISTRY.register() +class LPIPSLoss(nn.Module): + def __init__(self, loss_weight=1.0, use_input_norm=True, range_norm=False): + super(LPIPSLoss, self).__init__() + self.perceptual = lpips.LPIPS(net="vgg", spatial=False).eval() + self.loss_weight = loss_weight + self.use_input_norm = use_input_norm + self.range_norm = range_norm + + if self.use_input_norm: + # the mean is for image with range [0, 1] + self.register_buffer("mean", torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)) + # the std is for image with range [0, 1] + self.register_buffer("std", torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)) + + def forward(self, pred, target): + if self.range_norm: + pred = (pred + 1) / 2 + target = (target + 1) / 2 + if self.use_input_norm: + pred = (pred - self.mean) / self.std + target = (target - self.mean) / self.std + lpips_loss = self.perceptual(target.contiguous(), pred.contiguous()) + return self.loss_weight * lpips_loss.mean() + + +@LOSS_REGISTRY.register() +class GANLoss(nn.Module): + """Define GAN loss. + + Args: + gan_type (str): Support 'vanilla', 'lsgan', 'wgan', 'hinge'. + real_label_val (float): The value for real label. Default: 1.0. + fake_label_val (float): The value for fake label. Default: 0.0. + loss_weight (float): Loss weight. Default: 1.0. + Note that loss_weight is only for generators; and it is always 1.0 + for discriminators. + """ + + def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0): + super(GANLoss, self).__init__() + self.gan_type = gan_type + self.loss_weight = loss_weight + self.real_label_val = real_label_val + self.fake_label_val = fake_label_val + + if self.gan_type == "vanilla": + self.loss = nn.BCEWithLogitsLoss() + elif self.gan_type == "lsgan": + self.loss = nn.MSELoss() + elif self.gan_type == "wgan": + self.loss = self._wgan_loss + elif self.gan_type == "wgan_softplus": + self.loss = self._wgan_softplus_loss + elif self.gan_type == "hinge": + self.loss = nn.ReLU() + else: + raise NotImplementedError(f"GAN type {self.gan_type} is not implemented.") + + def _wgan_loss(self, input, target): + """wgan loss. + + Args: + input (Tensor): Input tensor. + target (bool): Target label. 
+ + Returns: + Tensor: wgan loss. + """ + return -input.mean() if target else input.mean() + + def _wgan_softplus_loss(self, input, target): + """wgan loss with soft plus. softplus is a smooth approximation to the + ReLU function. + + In StyleGAN2, it is called: + Logistic loss for discriminator; + Non-saturating loss for generator. + + Args: + input (Tensor): Input tensor. + target (bool): Target label. + + Returns: + Tensor: wgan loss. + """ + return F.softplus(-input).mean() if target else F.softplus(input).mean() + + def get_target_label(self, input, target_is_real): + """Get target label. + + Args: + input (Tensor): Input tensor. + target_is_real (bool): Whether the target is real or fake. + + Returns: + (bool | Tensor): Target tensor. Return bool for wgan, otherwise, + return Tensor. + """ + + if self.gan_type in ["wgan", "wgan_softplus"]: + return target_is_real + target_val = self.real_label_val if target_is_real else self.fake_label_val + return input.new_ones(input.size()) * target_val + + def forward(self, input, target_is_real, is_disc=False): + """ + Args: + input (Tensor): The input for the loss module, i.e., the network + prediction. + target_is_real (bool): Whether the targe is real or fake. + is_disc (bool): Whether the loss for discriminators or not. + Default: False. + + Returns: + Tensor: GAN loss value. + """ + if self.gan_type == "hinge": + if is_disc: # for discriminators in hinge-gan + input = -input if target_is_real else input + loss = self.loss(1 + input).mean() + else: # for generators in hinge-gan + loss = -input.mean() + else: # other gan types + target_label = self.get_target_label(input, target_is_real) + loss = self.loss(input, target_label) + + # loss_weight is always 1.0 for discriminators + return loss if is_disc else loss * self.loss_weight + + +def r1_penalty(real_pred, real_img): + """R1 regularization for discriminator. The core idea is to + penalize the gradient on real data alone: when the + generator distribution produces the true data distribution + and the discriminator is equal to 0 on the data manifold, the + gradient penalty ensures that the discriminator cannot create + a non-zero gradient orthogonal to the data manifold without + suffering a loss in the GAN game. + + Ref: + Eq. 9 in Which training methods for GANs do actually converge. + """ + grad_real = autograd.grad(outputs=real_pred.sum(), inputs=real_img, create_graph=True)[0] + grad_penalty = grad_real.pow(2).view(grad_real.shape[0], -1).sum(1).mean() + return grad_penalty + + +def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01): + noise = torch.randn_like(fake_img) / math.sqrt(fake_img.shape[2] * fake_img.shape[3]) + grad = autograd.grad(outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True)[0] + path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1)) + + path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length) + + path_penalty = (path_lengths - path_mean).pow(2).mean() + + return path_penalty, path_lengths.detach().mean(), path_mean.detach() + + +def gradient_penalty_loss(discriminator, real_data, fake_data, weight=None): + """Calculate gradient penalty for wgan-gp. + + Args: + discriminator (nn.Module): Network for the discriminator. + real_data (Tensor): Real input data. + fake_data (Tensor): Fake input data. + weight (Tensor): Weight tensor. Default: None. + + Returns: + Tensor: A tensor for gradient penalty. 
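The R1 penalty above needs the real images to require gradients and the graph to be kept for a second backward pass; a minimal standalone sketch with a throwaway discriminator:

    import torch
    from torch import autograd, nn

    disc = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 1))   # toy discriminator
    real_img = torch.randn(4, 3, 8, 8, requires_grad=True)
    real_pred = disc(real_img)
    grad_real = autograd.grad(outputs=real_pred.sum(), inputs=real_img, create_graph=True)[0]
    r1 = grad_real.pow(2).view(grad_real.shape[0], -1).sum(1).mean()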
+ """ + + batch_size = real_data.size(0) + alpha = real_data.new_tensor(torch.rand(batch_size, 1, 1, 1)) + + # interpolate between real_data and fake_data + interpolates = alpha * real_data + (1.0 - alpha) * fake_data + interpolates = autograd.Variable(interpolates, requires_grad=True) + + disc_interpolates = discriminator(interpolates) + gradients = autograd.grad( + outputs=disc_interpolates, + inputs=interpolates, + grad_outputs=torch.ones_like(disc_interpolates), + create_graph=True, + retain_graph=True, + only_inputs=True, + )[0] + + if weight is not None: + gradients = gradients * weight + + gradients_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() + if weight is not None: + gradients_penalty /= torch.mean(weight) + + return gradients_penalty diff --git a/hordelib/nodes/facerestore_cf/basicsr/metrics/__init__.py b/hordelib/nodes/facerestore_cf/basicsr/metrics/__init__.py new file mode 100644 index 00000000..5e9ae7c6 --- /dev/null +++ b/hordelib/nodes/facerestore_cf/basicsr/metrics/__init__.py @@ -0,0 +1,20 @@ +from copy import deepcopy + +from hordelib.nodes.facerestore_cf.basicsr.utils.registry import METRIC_REGISTRY + +from .psnr_ssim import calculate_psnr, calculate_ssim + +__all__ = ["calculate_psnr", "calculate_ssim"] + + +def calculate_metric(data, opt): + """Calculate metric from data and options. + + Args: + opt (dict): Configuration. It must constain: + type (str): Model type. + """ + opt = deepcopy(opt) + metric_type = opt.pop("type") + metric = METRIC_REGISTRY.get(metric_type)(**data, **opt) + return metric diff --git a/hordelib/nodes/facerestore_cf/basicsr/metrics/metric_util.py b/hordelib/nodes/facerestore_cf/basicsr/metrics/metric_util.py new file mode 100644 index 00000000..c77e591e --- /dev/null +++ b/hordelib/nodes/facerestore_cf/basicsr/metrics/metric_util.py @@ -0,0 +1,45 @@ +import numpy as np + +from hordelib.nodes.facerestore_cf.basicsr.utils.matlab_functions import bgr2ycbcr + + +def reorder_image(img, input_order="HWC"): + """Reorder images to 'HWC' order. + + If the input_order is (h, w), return (h, w, 1); + If the input_order is (c, h, w), return (h, w, c); + If the input_order is (h, w, c), return as it is. + + Args: + img (ndarray): Input image. + input_order (str): Whether the input order is 'HWC' or 'CHW'. + If the input image shape is (h, w), input_order will not have + effects. Default: 'HWC'. + + Returns: + ndarray: reordered image. + """ + + if input_order not in ["HWC", "CHW"]: + raise ValueError(f"Wrong input_order {input_order}. Supported input_orders are " "'HWC' and 'CHW'") + if len(img.shape) == 2: + img = img[..., None] + if input_order == "CHW": + img = img.transpose(1, 2, 0) + return img + + +def to_y_channel(img): + """Change to Y channel of YCbCr. + + Args: + img (ndarray): Images with range [0, 255]. + + Returns: + (ndarray): Images with range [0, 255] (float type) without round. 
+ """ + img = img.astype(np.float32) / 255.0 + if img.ndim == 3 and img.shape[2] == 3: + img = bgr2ycbcr(img, y_only=True) + img = img[..., None] + return img * 255.0 diff --git a/hordelib/nodes/facerestore_cf/basicsr/metrics/psnr_ssim.py b/hordelib/nodes/facerestore_cf/basicsr/metrics/psnr_ssim.py new file mode 100644 index 00000000..e7735956 --- /dev/null +++ b/hordelib/nodes/facerestore_cf/basicsr/metrics/psnr_ssim.py @@ -0,0 +1,128 @@ +import cv2 +import numpy as np + +from hordelib.nodes.facerestore_cf.basicsr.metrics.metric_util import reorder_image, to_y_channel +from hordelib.nodes.facerestore_cf.basicsr.utils.registry import METRIC_REGISTRY + + +@METRIC_REGISTRY.register() +def calculate_psnr(img1, img2, crop_border, input_order="HWC", test_y_channel=False): + """Calculate PSNR (Peak Signal-to-Noise Ratio). + + Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio + + Args: + img1 (ndarray): Images with range [0, 255]. + img2 (ndarray): Images with range [0, 255]. + crop_border (int): Cropped pixels in each edge of an image. These + pixels are not involved in the PSNR calculation. + input_order (str): Whether the input order is 'HWC' or 'CHW'. + Default: 'HWC'. + test_y_channel (bool): Test on Y channel of YCbCr. Default: False. + + Returns: + float: psnr result. + """ + + assert img1.shape == img2.shape, f"Image shapes are differnet: {img1.shape}, {img2.shape}." + if input_order not in ["HWC", "CHW"]: + raise ValueError(f"Wrong input_order {input_order}. Supported input_orders are " '"HWC" and "CHW"') + img1 = reorder_image(img1, input_order=input_order) + img2 = reorder_image(img2, input_order=input_order) + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + + if crop_border != 0: + img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...] + img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...] + + if test_y_channel: + img1 = to_y_channel(img1) + img2 = to_y_channel(img2) + + mse = np.mean((img1 - img2) ** 2) + if mse == 0: + return float("inf") + return 20.0 * np.log10(255.0 / np.sqrt(mse)) + + +def _ssim(img1, img2): + """Calculate SSIM (structural similarity) for one channel images. + + It is called by func:`calculate_ssim`. + + Args: + img1 (ndarray): Images with range [0, 255] with order 'HWC'. + img2 (ndarray): Images with range [0, 255] with order 'HWC'. + + Returns: + float: ssim result. + """ + + C1 = (0.01 * 255) ** 2 + C2 = (0.03 * 255) ** 2 + + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + kernel = cv2.getGaussianKernel(11, 1.5) + window = np.outer(kernel, kernel.transpose()) + + mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] + mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] + mu1_sq = mu1**2 + mu2_sq = mu2**2 + mu1_mu2 = mu1 * mu2 + sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq + sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq + sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 + + ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) + return ssim_map.mean() + + +@METRIC_REGISTRY.register() +def calculate_ssim(img1, img2, crop_border, input_order="HWC", test_y_channel=False): + """Calculate SSIM (structural similarity). + + Ref: + Image quality assessment: From error visibility to structural similarity + + The results are the same as that of the official released MATLAB code in + https://ece.uwaterloo.ca/~z70wang/research/ssim/. 
+ + For three-channel images, SSIM is calculated for each channel and then + averaged. + + Args: + img1 (ndarray): Images with range [0, 255]. + img2 (ndarray): Images with range [0, 255]. + crop_border (int): Cropped pixels in each edge of an image. These + pixels are not involved in the SSIM calculation. + input_order (str): Whether the input order is 'HWC' or 'CHW'. + Default: 'HWC'. + test_y_channel (bool): Test on Y channel of YCbCr. Default: False. + + Returns: + float: ssim result. + """ + + assert img1.shape == img2.shape, f"Image shapes are differnet: {img1.shape}, {img2.shape}." + if input_order not in ["HWC", "CHW"]: + raise ValueError(f"Wrong input_order {input_order}. Supported input_orders are " '"HWC" and "CHW"') + img1 = reorder_image(img1, input_order=input_order) + img2 = reorder_image(img2, input_order=input_order) + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + + if crop_border != 0: + img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...] + img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...] + + if test_y_channel: + img1 = to_y_channel(img1) + img2 = to_y_channel(img2) + + ssims = [] + for i in range(img1.shape[2]): + ssims.append(_ssim(img1[..., i], img2[..., i])) + return np.array(ssims).mean() diff --git a/hordelib/nodes/facerestore_cf/basicsr/models/__init__.py b/hordelib/nodes/facerestore_cf/basicsr/models/__init__.py new file mode 100644 index 00000000..73c723c6 --- /dev/null +++ b/hordelib/nodes/facerestore_cf/basicsr/models/__init__.py @@ -0,0 +1,30 @@ +import importlib +from copy import deepcopy +from os import path as osp + +from hordelib.nodes.facerestore_cf.basicsr.utils import get_root_logger, scandir +from hordelib.nodes.facerestore_cf.basicsr.utils.registry import MODEL_REGISTRY + +__all__ = ["build_model"] + +# automatically scan and import model modules for registry +# scan all the files under the 'models' folder and collect files ending with +# '_model.py' +model_folder = osp.dirname(osp.abspath(__file__)) +model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith("_model.py")] +# import all the model modules +_model_modules = [importlib.import_module(f"basicsr.models.{file_name}") for file_name in model_filenames] + + +def build_model(opt): + """Build model from options. + + Args: + opt (dict): Configuration. It must constain: + model_type (str): Model type. 
+ """ + opt = deepcopy(opt) + model = MODEL_REGISTRY.get(opt["model_type"])(opt) + logger = get_root_logger() + logger.info(f"Model [{model.__class__.__name__}] is created.") + return model diff --git a/hordelib/nodes/facerestore/facelib/__init__.py b/hordelib/nodes/facerestore_cf/basicsr/ops/__init__.py similarity index 100% rename from hordelib/nodes/facerestore/facelib/__init__.py rename to hordelib/nodes/facerestore_cf/basicsr/ops/__init__.py diff --git a/hordelib/nodes/facerestore_cf/basicsr/ops/dcn/__init__.py b/hordelib/nodes/facerestore_cf/basicsr/ops/dcn/__init__.py new file mode 100644 index 00000000..55b8b8e9 --- /dev/null +++ b/hordelib/nodes/facerestore_cf/basicsr/ops/dcn/__init__.py @@ -0,0 +1,17 @@ +from .deform_conv import ( + DeformConv, + DeformConvPack, + ModulatedDeformConv, + ModulatedDeformConvPack, + deform_conv, + modulated_deform_conv, +) + +__all__ = [ + "DeformConv", + "DeformConvPack", + "ModulatedDeformConv", + "ModulatedDeformConvPack", + "deform_conv", + "modulated_deform_conv", +] diff --git a/hordelib/nodes/facerestore_cf/basicsr/ops/dcn/deform_conv.py b/hordelib/nodes/facerestore_cf/basicsr/ops/dcn/deform_conv.py new file mode 100644 index 00000000..020c86d3 --- /dev/null +++ b/hordelib/nodes/facerestore_cf/basicsr/ops/dcn/deform_conv.py @@ -0,0 +1,503 @@ +import math + +import torch +from torch import nn as nn +from torch.autograd import Function +from torch.autograd.function import once_differentiable +from torch.nn import functional as F +from torch.nn.modules.utils import _pair, _single + +try: + from . import deform_conv_ext +except ImportError: + import os + + BASICSR_JIT = os.getenv("BASICSR_JIT") + if BASICSR_JIT == "True": + from torch.utils.cpp_extension import load + + module_path = os.path.dirname(__file__) + deform_conv_ext = load( + "deform_conv", + sources=[ + os.path.join(module_path, "src", "deform_conv_ext.cpp"), + os.path.join(module_path, "src", "deform_conv_cuda.cpp"), + os.path.join(module_path, "src", "deform_conv_cuda_kernel.cu"), + ], + ) + + +class DeformConvFunction(Function): + + @staticmethod + def forward( + ctx, + input, + offset, + weight, + stride=1, + padding=0, + dilation=1, + groups=1, + deformable_groups=1, + im2col_step=64, + ): + if input is not None and input.dim() != 4: + raise ValueError(f"Expected 4D tensor as input, got {input.dim()}" "D tensor instead.") + ctx.stride = _pair(stride) + ctx.padding = _pair(padding) + ctx.dilation = _pair(dilation) + ctx.groups = groups + ctx.deformable_groups = deformable_groups + ctx.im2col_step = im2col_step + + ctx.save_for_backward(input, offset, weight) + + output = input.new_empty(DeformConvFunction._output_size(input, weight, ctx.padding, ctx.dilation, ctx.stride)) + + ctx.bufs_ = [input.new_empty(0), input.new_empty(0)] # columns, ones + + if not input.is_cuda: + raise NotImplementedError + else: + cur_im2col_step = min(ctx.im2col_step, input.shape[0]) + assert (input.shape[0] % cur_im2col_step) == 0, "im2col step must divide batchsize" + deform_conv_ext.deform_conv_forward( + input, + weight, + offset, + output, + ctx.bufs_[0], + ctx.bufs_[1], + weight.size(3), + weight.size(2), + ctx.stride[1], + ctx.stride[0], + ctx.padding[1], + ctx.padding[0], + ctx.dilation[1], + ctx.dilation[0], + ctx.groups, + ctx.deformable_groups, + cur_im2col_step, + ) + return output + + @staticmethod + @once_differentiable + def backward(ctx, grad_output): + input, offset, weight = ctx.saved_tensors + + grad_input = grad_offset = grad_weight = None + + if not grad_output.is_cuda: + raise 
NotImplementedError + else: + cur_im2col_step = min(ctx.im2col_step, input.shape[0]) + assert (input.shape[0] % cur_im2col_step) == 0, "im2col step must divide batchsize" + + if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]: + grad_input = torch.zeros_like(input) + grad_offset = torch.zeros_like(offset) + deform_conv_ext.deform_conv_backward_input( + input, + offset, + grad_output, + grad_input, + grad_offset, + weight, + ctx.bufs_[0], + weight.size(3), + weight.size(2), + ctx.stride[1], + ctx.stride[0], + ctx.padding[1], + ctx.padding[0], + ctx.dilation[1], + ctx.dilation[0], + ctx.groups, + ctx.deformable_groups, + cur_im2col_step, + ) + + if ctx.needs_input_grad[2]: + grad_weight = torch.zeros_like(weight) + deform_conv_ext.deform_conv_backward_parameters( + input, + offset, + grad_output, + grad_weight, + ctx.bufs_[0], + ctx.bufs_[1], + weight.size(3), + weight.size(2), + ctx.stride[1], + ctx.stride[0], + ctx.padding[1], + ctx.padding[0], + ctx.dilation[1], + ctx.dilation[0], + ctx.groups, + ctx.deformable_groups, + 1, + cur_im2col_step, + ) + + return (grad_input, grad_offset, grad_weight, None, None, None, None, None) + + @staticmethod + def _output_size(input, weight, padding, dilation, stride): + channels = weight.size(0) + output_size = (input.size(0), channels) + for d in range(input.dim() - 2): + in_size = input.size(d + 2) + pad = padding[d] + kernel = dilation[d] * (weight.size(d + 2) - 1) + 1 + stride_ = stride[d] + output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1,) + if not all(map(lambda s: s > 0, output_size)): + raise ValueError("convolution input is too small (output would be " f'{"x".join(map(str, output_size))})') + return output_size + + +class ModulatedDeformConvFunction(Function): + + @staticmethod + def forward( + ctx, + input, + offset, + mask, + weight, + bias=None, + stride=1, + padding=0, + dilation=1, + groups=1, + deformable_groups=1, + ): + ctx.stride = stride + ctx.padding = padding + ctx.dilation = dilation + ctx.groups = groups + ctx.deformable_groups = deformable_groups + ctx.with_bias = bias is not None + if not ctx.with_bias: + bias = input.new_empty(1) # fake tensor + if not input.is_cuda: + raise NotImplementedError + if weight.requires_grad or mask.requires_grad or offset.requires_grad or input.requires_grad: + ctx.save_for_backward(input, offset, mask, weight, bias) + output = input.new_empty(ModulatedDeformConvFunction._infer_shape(ctx, input, weight)) + ctx._bufs = [input.new_empty(0), input.new_empty(0)] + deform_conv_ext.modulated_deform_conv_forward( + input, + weight, + bias, + ctx._bufs[0], + offset, + mask, + output, + ctx._bufs[1], + weight.shape[2], + weight.shape[3], + ctx.stride, + ctx.stride, + ctx.padding, + ctx.padding, + ctx.dilation, + ctx.dilation, + ctx.groups, + ctx.deformable_groups, + ctx.with_bias, + ) + return output + + @staticmethod + @once_differentiable + def backward(ctx, grad_output): + if not grad_output.is_cuda: + raise NotImplementedError + input, offset, mask, weight, bias = ctx.saved_tensors + grad_input = torch.zeros_like(input) + grad_offset = torch.zeros_like(offset) + grad_mask = torch.zeros_like(mask) + grad_weight = torch.zeros_like(weight) + grad_bias = torch.zeros_like(bias) + deform_conv_ext.modulated_deform_conv_backward( + input, + weight, + bias, + ctx._bufs[0], + offset, + mask, + ctx._bufs[1], + grad_input, + grad_weight, + grad_bias, + grad_offset, + grad_mask, + grad_output, + weight.shape[2], + weight.shape[3], + ctx.stride, + ctx.stride, + ctx.padding, + ctx.padding, + 
ctx.dilation, + ctx.dilation, + ctx.groups, + ctx.deformable_groups, + ctx.with_bias, + ) + if not ctx.with_bias: + grad_bias = None + + return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias, None, None, None, None, None) + + @staticmethod + def _infer_shape(ctx, input, weight): + n = input.size(0) + channels_out = weight.size(0) + height, width = input.shape[2:4] + kernel_h, kernel_w = weight.shape[2:4] + height_out = (height + 2 * ctx.padding - (ctx.dilation * (kernel_h - 1) + 1)) // ctx.stride + 1 + width_out = (width + 2 * ctx.padding - (ctx.dilation * (kernel_w - 1) + 1)) // ctx.stride + 1 + return n, channels_out, height_out, width_out + + +deform_conv = DeformConvFunction.apply +modulated_deform_conv = ModulatedDeformConvFunction.apply + + +class DeformConv(nn.Module): + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + deformable_groups=1, + bias=False, + ): + super(DeformConv, self).__init__() + + assert not bias + assert in_channels % groups == 0, f"in_channels {in_channels} is not divisible by groups {groups}" + assert out_channels % groups == 0, f"out_channels {out_channels} is not divisible " f"by groups {groups}" + + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = _pair(kernel_size) + self.stride = _pair(stride) + self.padding = _pair(padding) + self.dilation = _pair(dilation) + self.groups = groups + self.deformable_groups = deformable_groups + # enable compatibility with nn.Conv2d + self.transposed = False + self.output_padding = _single(0) + + self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // self.groups, *self.kernel_size)) + + self.reset_parameters() + + def reset_parameters(self): + n = self.in_channels + for k in self.kernel_size: + n *= k + stdv = 1.0 / math.sqrt(n) + self.weight.data.uniform_(-stdv, stdv) + + def forward(self, x, offset): + # To fix an assert error in deform_conv_cuda.cpp:128 + # input image is smaller than kernel + input_pad = x.size(2) < self.kernel_size[0] or x.size(3) < self.kernel_size[1] + if input_pad: + pad_h = max(self.kernel_size[0] - x.size(2), 0) + pad_w = max(self.kernel_size[1] - x.size(3), 0) + x = F.pad(x, (0, pad_w, 0, pad_h), "constant", 0).contiguous() + offset = F.pad(offset, (0, pad_w, 0, pad_h), "constant", 0).contiguous() + out = deform_conv( + x, + offset, + self.weight, + self.stride, + self.padding, + self.dilation, + self.groups, + self.deformable_groups, + ) + if input_pad: + out = out[:, :, : out.size(2) - pad_h, : out.size(3) - pad_w].contiguous() + return out + + +class DeformConvPack(DeformConv): + """A Deformable Conv Encapsulation that acts as normal Conv layers. + + Args: + in_channels (int): Same as nn.Conv2d. + out_channels (int): Same as nn.Conv2d. + kernel_size (int or tuple[int]): Same as nn.Conv2d. + stride (int or tuple[int]): Same as nn.Conv2d. + padding (int or tuple[int]): Same as nn.Conv2d. + dilation (int or tuple[int]): Same as nn.Conv2d. + groups (int): Same as nn.Conv2d. + bias (bool or str): If specified as `auto`, it will be decided by the + norm_cfg. Bias will be set as True if norm_cfg is None, otherwise + False. 
+ """ + + _version = 2 + + def __init__(self, *args, **kwargs): + super(DeformConvPack, self).__init__(*args, **kwargs) + + self.conv_offset = nn.Conv2d( + self.in_channels, + self.deformable_groups * 2 * self.kernel_size[0] * self.kernel_size[1], + kernel_size=self.kernel_size, + stride=_pair(self.stride), + padding=_pair(self.padding), + dilation=_pair(self.dilation), + bias=True, + ) + self.init_offset() + + def init_offset(self): + self.conv_offset.weight.data.zero_() + self.conv_offset.bias.data.zero_() + + def forward(self, x): + offset = self.conv_offset(x) + return deform_conv( + x, + offset, + self.weight, + self.stride, + self.padding, + self.dilation, + self.groups, + self.deformable_groups, + ) + + +class ModulatedDeformConv(nn.Module): + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + deformable_groups=1, + bias=True, + ): + super(ModulatedDeformConv, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = _pair(kernel_size) + self.stride = stride + self.padding = padding + self.dilation = dilation + self.groups = groups + self.deformable_groups = deformable_groups + self.with_bias = bias + # enable compatibility with nn.Conv2d + self.transposed = False + self.output_padding = _single(0) + + self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)) + if bias: + self.bias = nn.Parameter(torch.Tensor(out_channels)) + else: + self.register_parameter("bias", None) + self.init_weights() + + def init_weights(self): + n = self.in_channels + for k in self.kernel_size: + n *= k + stdv = 1.0 / math.sqrt(n) + self.weight.data.uniform_(-stdv, stdv) + if self.bias is not None: + self.bias.data.zero_() + + def forward(self, x, offset, mask): + return modulated_deform_conv( + x, + offset, + mask, + self.weight, + self.bias, + self.stride, + self.padding, + self.dilation, + self.groups, + self.deformable_groups, + ) + + +class ModulatedDeformConvPack(ModulatedDeformConv): + """A ModulatedDeformable Conv Encapsulation that acts as normal Conv layers. + + Args: + in_channels (int): Same as nn.Conv2d. + out_channels (int): Same as nn.Conv2d. + kernel_size (int or tuple[int]): Same as nn.Conv2d. + stride (int or tuple[int]): Same as nn.Conv2d. + padding (int or tuple[int]): Same as nn.Conv2d. + dilation (int or tuple[int]): Same as nn.Conv2d. + groups (int): Same as nn.Conv2d. + bias (bool or str): If specified as `auto`, it will be decided by the + norm_cfg. Bias will be set as True if norm_cfg is None, otherwise + False. 
+ """ + + _version = 2 + + def __init__(self, *args, **kwargs): + super(ModulatedDeformConvPack, self).__init__(*args, **kwargs) + + self.conv_offset = nn.Conv2d( + self.in_channels, + self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1], + kernel_size=self.kernel_size, + stride=_pair(self.stride), + padding=_pair(self.padding), + dilation=_pair(self.dilation), + bias=True, + ) + self.init_weights() + + def init_weights(self): + super(ModulatedDeformConvPack, self).init_weights() + if hasattr(self, "conv_offset"): + self.conv_offset.weight.data.zero_() + self.conv_offset.bias.data.zero_() + + def forward(self, x): + out = self.conv_offset(x) + o1, o2, mask = torch.chunk(out, 3, dim=1) + offset = torch.cat((o1, o2), dim=1) + mask = torch.sigmoid(mask) + return modulated_deform_conv( + x, + offset, + mask, + self.weight, + self.bias, + self.stride, + self.padding, + self.dilation, + self.groups, + self.deformable_groups, + ) diff --git a/hordelib/nodes/facerestore_cf/basicsr/ops/dcn/src/deform_conv_cuda.cpp b/hordelib/nodes/facerestore_cf/basicsr/ops/dcn/src/deform_conv_cuda.cpp new file mode 100644 index 00000000..6fbef833 --- /dev/null +++ b/hordelib/nodes/facerestore_cf/basicsr/ops/dcn/src/deform_conv_cuda.cpp @@ -0,0 +1,685 @@ +// modify from +// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c + +#include +#include + +#include +#include + +void deformable_im2col(const at::Tensor data_im, const at::Tensor data_offset, + const int channels, const int height, const int width, + const int ksize_h, const int ksize_w, const int pad_h, + const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int parallel_imgs, const int deformable_group, + at::Tensor data_col); + +void deformable_col2im(const at::Tensor data_col, const at::Tensor data_offset, + const int channels, const int height, const int width, + const int ksize_h, const int ksize_w, const int pad_h, + const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int parallel_imgs, const int deformable_group, + at::Tensor grad_im); + +void deformable_col2im_coord( + const at::Tensor data_col, const at::Tensor data_im, + const at::Tensor data_offset, const int channels, const int height, + const int width, const int ksize_h, const int ksize_w, const int pad_h, + const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int parallel_imgs, + const int deformable_group, at::Tensor grad_offset); + +void modulated_deformable_im2col_cuda( + const at::Tensor data_im, const at::Tensor data_offset, + const at::Tensor data_mask, const int batch_size, const int channels, + const int height_im, const int width_im, const int height_col, + const int width_col, const int kernel_h, const int kenerl_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int deformable_group, + at::Tensor data_col); + +void modulated_deformable_col2im_cuda( + const at::Tensor data_col, const at::Tensor data_offset, + const at::Tensor data_mask, const int batch_size, const int channels, + const int height_im, const int width_im, const int height_col, + const int width_col, const int kernel_h, const int kenerl_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int deformable_group, + 
at::Tensor grad_im); + +void modulated_deformable_col2im_coord_cuda( + const at::Tensor data_col, const at::Tensor data_im, + const at::Tensor data_offset, const at::Tensor data_mask, + const int batch_size, const int channels, const int height_im, + const int width_im, const int height_col, const int width_col, + const int kernel_h, const int kenerl_w, const int pad_h, const int pad_w, + const int stride_h, const int stride_w, const int dilation_h, + const int dilation_w, const int deformable_group, at::Tensor grad_offset, + at::Tensor grad_mask); + +void shape_check(at::Tensor input, at::Tensor offset, at::Tensor *gradOutput, + at::Tensor weight, int kH, int kW, int dH, int dW, int padH, + int padW, int dilationH, int dilationW, int group, + int deformable_group) { + TORCH_CHECK(weight.ndimension() == 4, + "4D weight tensor (nOutputPlane,nInputPlane,kH,kW) expected, " + "but got: %s", + weight.ndimension()); + + TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous"); + + TORCH_CHECK(kW > 0 && kH > 0, + "kernel size should be greater than zero, but got kH: %d kW: %d", kH, + kW); + + TORCH_CHECK((weight.size(2) == kH && weight.size(3) == kW), + "kernel size should be consistent with weight, ", + "but got kH: %d kW: %d weight.size(2): %d, weight.size(3): %d", kH, + kW, weight.size(2), weight.size(3)); + + TORCH_CHECK(dW > 0 && dH > 0, + "stride should be greater than zero, but got dH: %d dW: %d", dH, dW); + + TORCH_CHECK( + dilationW > 0 && dilationH > 0, + "dilation should be greater than 0, but got dilationH: %d dilationW: %d", + dilationH, dilationW); + + int ndim = input.ndimension(); + int dimf = 0; + int dimh = 1; + int dimw = 2; + + if (ndim == 4) { + dimf++; + dimh++; + dimw++; + } + + TORCH_CHECK(ndim == 3 || ndim == 4, "3D or 4D input tensor expected but got: %s", + ndim); + + long nInputPlane = weight.size(1) * group; + long inputHeight = input.size(dimh); + long inputWidth = input.size(dimw); + long nOutputPlane = weight.size(0); + long outputHeight = + (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; + long outputWidth = + (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; + + TORCH_CHECK(nInputPlane % deformable_group == 0, + "input channels must divide deformable group size"); + + if (outputWidth < 1 || outputHeight < 1) + AT_ERROR( + "Given input size: (%ld x %ld x %ld). " + "Calculated output size: (%ld x %ld x %ld). 
Output size is too small", + nInputPlane, inputHeight, inputWidth, nOutputPlane, outputHeight, + outputWidth); + + TORCH_CHECK(input.size(1) == nInputPlane, + "invalid number of input planes, expected: %d, but got: %d", + nInputPlane, input.size(1)); + + TORCH_CHECK((inputHeight >= kH && inputWidth >= kW), + "input image is smaller than kernel"); + + TORCH_CHECK((offset.size(2) == outputHeight && offset.size(3) == outputWidth), + "invalid spatial size of offset, expected height: %d width: %d, but " + "got height: %d width: %d", + outputHeight, outputWidth, offset.size(2), offset.size(3)); + + TORCH_CHECK((offset.size(1) == deformable_group * 2 * kH * kW), + "invalid number of channels of offset"); + + if (gradOutput != NULL) { + TORCH_CHECK(gradOutput->size(dimf) == nOutputPlane, + "invalid number of gradOutput planes, expected: %d, but got: %d", + nOutputPlane, gradOutput->size(dimf)); + + TORCH_CHECK((gradOutput->size(dimh) == outputHeight && + gradOutput->size(dimw) == outputWidth), + "invalid size of gradOutput, expected height: %d width: %d , but " + "got height: %d width: %d", + outputHeight, outputWidth, gradOutput->size(dimh), + gradOutput->size(dimw)); + } +} + +int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight, + at::Tensor offset, at::Tensor output, + at::Tensor columns, at::Tensor ones, int kW, + int kH, int dW, int dH, int padW, int padH, + int dilationW, int dilationH, int group, + int deformable_group, int im2col_step) { + // todo: resize columns to include im2col: done + // todo: add im2col_step as input + // todo: add new output buffer and transpose it to output (or directly + // transpose output) todo: possibly change data indexing because of + // parallel_imgs + + shape_check(input, offset, NULL, weight, kH, kW, dH, dW, padH, padW, + dilationH, dilationW, group, deformable_group); + at::DeviceGuard guard(input.device()); + + input = input.contiguous(); + offset = offset.contiguous(); + weight = weight.contiguous(); + + int batch = 1; + if (input.ndimension() == 3) { + // Force batch + batch = 0; + input.unsqueeze_(0); + offset.unsqueeze_(0); + } + + // todo: assert batchsize dividable by im2col_step + + long batchSize = input.size(0); + long nInputPlane = input.size(1); + long inputHeight = input.size(2); + long inputWidth = input.size(3); + + long nOutputPlane = weight.size(0); + + long outputWidth = + (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; + long outputHeight = + (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; + + TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset"); + + output = output.view({batchSize / im2col_step, im2col_step, nOutputPlane, + outputHeight, outputWidth}); + columns = at::zeros( + {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth}, + input.options()); + + if (ones.ndimension() != 2 || + ones.size(0) * ones.size(1) < outputHeight * outputWidth) { + ones = at::ones({outputHeight, outputWidth}, input.options()); + } + + input = input.view({batchSize / im2col_step, im2col_step, nInputPlane, + inputHeight, inputWidth}); + offset = + offset.view({batchSize / im2col_step, im2col_step, + deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + at::Tensor output_buffer = + at::zeros({batchSize / im2col_step, nOutputPlane, + im2col_step * outputHeight, outputWidth}, + output.options()); + + output_buffer = output_buffer.view( + {output_buffer.size(0), group, output_buffer.size(1) / group, + output_buffer.size(2), output_buffer.size(3)}); + + for (int elt = 0; elt 
< batchSize / im2col_step; elt++) { + deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight, + inputWidth, kH, kW, padH, padW, dH, dW, dilationH, + dilationW, im2col_step, deformable_group, columns); + + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + weight = weight.view({group, weight.size(0) / group, weight.size(1), + weight.size(2), weight.size(3)}); + + for (int g = 0; g < group; g++) { + output_buffer[elt][g] = output_buffer[elt][g] + .flatten(1) + .addmm_(weight[g].flatten(1), columns[g]) + .view_as(output_buffer[elt][g]); + } + } + + output_buffer = output_buffer.view( + {output_buffer.size(0), output_buffer.size(1) * output_buffer.size(2), + output_buffer.size(3), output_buffer.size(4)}); + + output_buffer = output_buffer.view({batchSize / im2col_step, nOutputPlane, + im2col_step, outputHeight, outputWidth}); + output_buffer.transpose_(1, 2); + output.copy_(output_buffer); + output = output.view({batchSize, nOutputPlane, outputHeight, outputWidth}); + + input = input.view({batchSize, nInputPlane, inputHeight, inputWidth}); + offset = offset.view( + {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + if (batch == 0) { + output = output.view({nOutputPlane, outputHeight, outputWidth}); + input = input.view({nInputPlane, inputHeight, inputWidth}); + offset = offset.view({offset.size(1), offset.size(2), offset.size(3)}); + } + + return 1; +} + +int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset, + at::Tensor gradOutput, at::Tensor gradInput, + at::Tensor gradOffset, at::Tensor weight, + at::Tensor columns, int kW, int kH, int dW, + int dH, int padW, int padH, int dilationW, + int dilationH, int group, + int deformable_group, int im2col_step) { + shape_check(input, offset, &gradOutput, weight, kH, kW, dH, dW, padH, padW, + dilationH, dilationW, group, deformable_group); + at::DeviceGuard guard(input.device()); + + input = input.contiguous(); + offset = offset.contiguous(); + gradOutput = gradOutput.contiguous(); + weight = weight.contiguous(); + + int batch = 1; + + if (input.ndimension() == 3) { + // Force batch + batch = 0; + input = input.view({1, input.size(0), input.size(1), input.size(2)}); + offset = offset.view({1, offset.size(0), offset.size(1), offset.size(2)}); + gradOutput = gradOutput.view( + {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)}); + } + + long batchSize = input.size(0); + long nInputPlane = input.size(1); + long inputHeight = input.size(2); + long inputWidth = input.size(3); + + long nOutputPlane = weight.size(0); + + long outputWidth = + (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; + long outputHeight = + (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; + + TORCH_CHECK((offset.size(0) == batchSize), 3, "invalid batch size of offset"); + gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth}); + columns = at::zeros( + {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth}, + input.options()); + + // change order of grad output + gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step, + nOutputPlane, outputHeight, outputWidth}); + gradOutput.transpose_(1, 2); + + gradInput = gradInput.view({batchSize / im2col_step, im2col_step, nInputPlane, + inputHeight, inputWidth}); + input = input.view({batchSize / im2col_step, im2col_step, nInputPlane, + inputHeight, inputWidth}); + gradOffset = gradOffset.view({batchSize / im2col_step, im2col_step, + deformable_group * 2 * kH * kW, outputHeight, 
+ outputWidth}); + offset = + offset.view({batchSize / im2col_step, im2col_step, + deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + for (int elt = 0; elt < batchSize / im2col_step; elt++) { + // divide into groups + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + weight = weight.view({group, weight.size(0) / group, weight.size(1), + weight.size(2), weight.size(3)}); + gradOutput = gradOutput.view( + {gradOutput.size(0), group, gradOutput.size(1) / group, + gradOutput.size(2), gradOutput.size(3), gradOutput.size(4)}); + + for (int g = 0; g < group; g++) { + columns[g] = columns[g].addmm_(weight[g].flatten(1).transpose(0, 1), + gradOutput[elt][g].flatten(1), 0.0f, 1.0f); + } + + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + gradOutput = gradOutput.view( + {gradOutput.size(0), gradOutput.size(1) * gradOutput.size(2), + gradOutput.size(3), gradOutput.size(4), gradOutput.size(5)}); + + deformable_col2im_coord(columns, input[elt], offset[elt], nInputPlane, + inputHeight, inputWidth, kH, kW, padH, padW, dH, dW, + dilationH, dilationW, im2col_step, deformable_group, + gradOffset[elt]); + + deformable_col2im(columns, offset[elt], nInputPlane, inputHeight, + inputWidth, kH, kW, padH, padW, dH, dW, dilationH, + dilationW, im2col_step, deformable_group, gradInput[elt]); + } + + gradOutput.transpose_(1, 2); + gradOutput = + gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth}); + + gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth}); + input = input.view({batchSize, nInputPlane, inputHeight, inputWidth}); + gradOffset = gradOffset.view( + {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + offset = offset.view( + {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + if (batch == 0) { + gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth}); + input = input.view({nInputPlane, inputHeight, inputWidth}); + gradInput = gradInput.view({nInputPlane, inputHeight, inputWidth}); + offset = offset.view({offset.size(1), offset.size(2), offset.size(3)}); + gradOffset = + gradOffset.view({offset.size(1), offset.size(2), offset.size(3)}); + } + + return 1; +} + +int deform_conv_backward_parameters_cuda( + at::Tensor input, at::Tensor offset, at::Tensor gradOutput, + at::Tensor gradWeight, // at::Tensor gradBias, + at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH, + int padW, int padH, int dilationW, int dilationH, int group, + int deformable_group, float scale, int im2col_step) { + // todo: transpose and reshape outGrad + // todo: reshape columns + // todo: add im2col_step as input + + shape_check(input, offset, &gradOutput, gradWeight, kH, kW, dH, dW, padH, + padW, dilationH, dilationW, group, deformable_group); + at::DeviceGuard guard(input.device()); + + input = input.contiguous(); + offset = offset.contiguous(); + gradOutput = gradOutput.contiguous(); + + int batch = 1; + + if (input.ndimension() == 3) { + // Force batch + batch = 0; + input = input.view( + at::IntList({1, input.size(0), input.size(1), input.size(2)})); + gradOutput = gradOutput.view( + {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)}); + } + + long batchSize = input.size(0); + long nInputPlane = input.size(1); + long inputHeight = input.size(2); + long inputWidth = input.size(3); + + long nOutputPlane = gradWeight.size(0); + + long outputWidth = + (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; + long outputHeight 
= + (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; + + TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset"); + + columns = at::zeros( + {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth}, + input.options()); + + gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step, + nOutputPlane, outputHeight, outputWidth}); + gradOutput.transpose_(1, 2); + + at::Tensor gradOutputBuffer = at::zeros_like(gradOutput); + gradOutputBuffer = + gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane, im2col_step, + outputHeight, outputWidth}); + gradOutputBuffer.copy_(gradOutput); + gradOutputBuffer = + gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane, + im2col_step * outputHeight, outputWidth}); + + gradOutput.transpose_(1, 2); + gradOutput = + gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth}); + + input = input.view({batchSize / im2col_step, im2col_step, nInputPlane, + inputHeight, inputWidth}); + offset = + offset.view({batchSize / im2col_step, im2col_step, + deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + for (int elt = 0; elt < batchSize / im2col_step; elt++) { + deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight, + inputWidth, kH, kW, padH, padW, dH, dW, dilationH, + dilationW, im2col_step, deformable_group, columns); + + // divide into group + gradOutputBuffer = gradOutputBuffer.view( + {gradOutputBuffer.size(0), group, gradOutputBuffer.size(1) / group, + gradOutputBuffer.size(2), gradOutputBuffer.size(3)}); + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + gradWeight = + gradWeight.view({group, gradWeight.size(0) / group, gradWeight.size(1), + gradWeight.size(2), gradWeight.size(3)}); + + for (int g = 0; g < group; g++) { + gradWeight[g] = gradWeight[g] + .flatten(1) + .addmm_(gradOutputBuffer[elt][g].flatten(1), + columns[g].transpose(1, 0), 1.0, scale) + .view_as(gradWeight[g]); + } + gradOutputBuffer = gradOutputBuffer.view( + {gradOutputBuffer.size(0), + gradOutputBuffer.size(1) * gradOutputBuffer.size(2), + gradOutputBuffer.size(3), gradOutputBuffer.size(4)}); + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + gradWeight = gradWeight.view({gradWeight.size(0) * gradWeight.size(1), + gradWeight.size(2), gradWeight.size(3), + gradWeight.size(4)}); + } + + input = input.view({batchSize, nInputPlane, inputHeight, inputWidth}); + offset = offset.view( + {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + if (batch == 0) { + gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth}); + input = input.view({nInputPlane, inputHeight, inputWidth}); + } + + return 1; +} + +void modulated_deform_conv_cuda_forward( + at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, + at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns, + int kernel_h, int kernel_w, const int stride_h, const int stride_w, + const int pad_h, const int pad_w, const int dilation_h, + const int dilation_w, const int group, const int deformable_group, + const bool with_bias) { + TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous"); + TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous"); + at::DeviceGuard guard(input.device()); + + const int batch = input.size(0); + const int channels = input.size(1); + const int height = input.size(2); + const int width = input.size(3); + + const int channels_out = weight.size(0); + const int 
channels_kernel = weight.size(1); + const int kernel_h_ = weight.size(2); + const int kernel_w_ = weight.size(3); + + if (kernel_h_ != kernel_h || kernel_w_ != kernel_w) + AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).", + kernel_h_, kernel_w, kernel_h_, kernel_w_); + if (channels != channels_kernel * group) + AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).", + channels, channels_kernel * group); + + const int height_out = + (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; + const int width_out = + (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; + + if (ones.ndimension() != 2 || + ones.size(0) * ones.size(1) < height_out * width_out) { + // Resize plane and fill with ones... + ones = at::ones({height_out, width_out}, input.options()); + } + + // resize output + output = output.view({batch, channels_out, height_out, width_out}).zero_(); + // resize temporary columns + columns = + at::zeros({channels * kernel_h * kernel_w, 1 * height_out * width_out}, + input.options()); + + output = output.view({output.size(0), group, output.size(1) / group, + output.size(2), output.size(3)}); + + for (int b = 0; b < batch; b++) { + modulated_deformable_im2col_cuda( + input[b], offset[b], mask[b], 1, channels, height, width, height_out, + width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, deformable_group, columns); + + // divide into group + weight = weight.view({group, weight.size(0) / group, weight.size(1), + weight.size(2), weight.size(3)}); + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + + for (int g = 0; g < group; g++) { + output[b][g] = output[b][g] + .flatten(1) + .addmm_(weight[g].flatten(1), columns[g]) + .view_as(output[b][g]); + } + + weight = weight.view({weight.size(0) * weight.size(1), weight.size(2), + weight.size(3), weight.size(4)}); + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + } + + output = output.view({output.size(0), output.size(1) * output.size(2), + output.size(3), output.size(4)}); + + if (with_bias) { + output += bias.view({1, bias.size(0), 1, 1}); + } +} + +void modulated_deform_conv_cuda_backward( + at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, + at::Tensor offset, at::Tensor mask, at::Tensor columns, + at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias, + at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output, + int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h, + int pad_w, int dilation_h, int dilation_w, int group, int deformable_group, + const bool with_bias) { + TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous"); + TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous"); + at::DeviceGuard guard(input.device()); + + const int batch = input.size(0); + const int channels = input.size(1); + const int height = input.size(2); + const int width = input.size(3); + + const int channels_kernel = weight.size(1); + const int kernel_h_ = weight.size(2); + const int kernel_w_ = weight.size(3); + if (kernel_h_ != kernel_h || kernel_w_ != kernel_w) + AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).", + kernel_h_, kernel_w, kernel_h_, kernel_w_); + if (channels != channels_kernel * group) + AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).", + channels, channels_kernel * group); + + const int height_out = + (height + 2 * pad_h - (dilation_h * 
(kernel_h - 1) + 1)) / stride_h + 1; + const int width_out = + (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; + + if (ones.ndimension() != 2 || + ones.size(0) * ones.size(1) < height_out * width_out) { + // Resize plane and fill with ones... + ones = at::ones({height_out, width_out}, input.options()); + } + + grad_input = grad_input.view({batch, channels, height, width}); + columns = at::zeros({channels * kernel_h * kernel_w, height_out * width_out}, + input.options()); + + grad_output = + grad_output.view({grad_output.size(0), group, grad_output.size(1) / group, + grad_output.size(2), grad_output.size(3)}); + + for (int b = 0; b < batch; b++) { + // divide int group + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + weight = weight.view({group, weight.size(0) / group, weight.size(1), + weight.size(2), weight.size(3)}); + + for (int g = 0; g < group; g++) { + columns[g].addmm_(weight[g].flatten(1).transpose(0, 1), + grad_output[b][g].flatten(1), 0.0f, 1.0f); + } + + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + weight = weight.view({weight.size(0) * weight.size(1), weight.size(2), + weight.size(3), weight.size(4)}); + + // gradient w.r.t. input coordinate data + modulated_deformable_col2im_coord_cuda( + columns, input[b], offset[b], mask[b], 1, channels, height, width, + height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, + stride_w, dilation_h, dilation_w, deformable_group, grad_offset[b], + grad_mask[b]); + // gradient w.r.t. input data + modulated_deformable_col2im_cuda( + columns, offset[b], mask[b], 1, channels, height, width, height_out, + width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, deformable_group, grad_input[b]); + + // gradient w.r.t. 
weight, dWeight should accumulate across the batch and + // group + modulated_deformable_im2col_cuda( + input[b], offset[b], mask[b], 1, channels, height, width, height_out, + width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, deformable_group, columns); + + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + grad_weight = grad_weight.view({group, grad_weight.size(0) / group, + grad_weight.size(1), grad_weight.size(2), + grad_weight.size(3)}); + if (with_bias) + grad_bias = grad_bias.view({group, grad_bias.size(0) / group}); + + for (int g = 0; g < group; g++) { + grad_weight[g] = + grad_weight[g] + .flatten(1) + .addmm_(grad_output[b][g].flatten(1), columns[g].transpose(0, 1)) + .view_as(grad_weight[g]); + if (with_bias) { + grad_bias[g] = + grad_bias[g] + .view({-1, 1}) + .addmm_(grad_output[b][g].flatten(1), ones.view({-1, 1})) + .view(-1); + } + } + + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + grad_weight = grad_weight.view({grad_weight.size(0) * grad_weight.size(1), + grad_weight.size(2), grad_weight.size(3), + grad_weight.size(4)}); + if (with_bias) + grad_bias = grad_bias.view({grad_bias.size(0) * grad_bias.size(1)}); + } + grad_output = grad_output.view({grad_output.size(0) * grad_output.size(1), + grad_output.size(2), grad_output.size(3), + grad_output.size(4)}); +} diff --git a/hordelib/nodes/facerestore_cf/basicsr/ops/dcn/src/deform_conv_cuda_kernel.cu b/hordelib/nodes/facerestore_cf/basicsr/ops/dcn/src/deform_conv_cuda_kernel.cu new file mode 100644 index 00000000..9fe9ba3a --- /dev/null +++ b/hordelib/nodes/facerestore_cf/basicsr/ops/dcn/src/deform_conv_cuda_kernel.cu @@ -0,0 +1,867 @@ +/*! + ******************* BEGIN Caffe Copyright Notice and Disclaimer **************** + * + * COPYRIGHT + * + * All contributions by the University of California: + * Copyright (c) 2014-2017 The Regents of the University of California (Regents) + * All rights reserved. + * + * All other contributions: + * Copyright (c) 2014-2017, the respective contributors + * All rights reserved. + * + * Caffe uses a shared copyright model: each contributor holds copyright over + * their contributions to Caffe. The project versioning records all such + * contribution and copyright details. If a contributor wants to further mark + * their specific copyright on a particular contribution, they should indicate + * their copyright solely in the commit message of the change when it is + * committed. + * + * LICENSE + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
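The modulated_deform_conv_cuda_forward/backward entry points above are what ModulatedDeformConvPack ultimately dispatches to. A hedged end-to-end sketch (editorial, again assuming a CUDA device and a built extension):

    import torch

    from hordelib.nodes.facerestore_cf.basicsr.ops.dcn import ModulatedDeformConvPack

    if torch.cuda.is_available():
        dcn = ModulatedDeformConvPack(3, 8, kernel_size=3, padding=1).cuda()
        x = torch.randn(1, 3, 32, 32, device="cuda", requires_grad=True)
        y = dcn(x)          # exercises the modulated forward path
        y.sum().backward()  # exercises the modulated backward path
        print(y.shape, x.grad.shape)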
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR + * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * CONTRIBUTION AGREEMENT + * + * By contributing to the BVLC/caffe repository through pull-request, comment, + * or otherwise, the contributor releases their content to the + * license and copyright terms herein. + * + ***************** END Caffe Copyright Notice and Disclaimer ******************** + * + * Copyright (c) 2018 Microsoft + * Licensed under The MIT License [see LICENSE for details] + * \file modulated_deformable_im2col.cuh + * \brief Function definitions of converting an image to + * column matrix based on kernel, padding, dilation, and offset. + * These functions are mainly used in deformable convolution operators. + * \ref: https://arxiv.org/abs/1703.06211 + * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu, Dazhi Cheng + */ + +// modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu + +#include +#include +#include +#include +#include +#include + +using namespace at; + +#define CUDA_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ + i += blockDim.x * gridDim.x) + +const int CUDA_NUM_THREADS = 1024; +const int kMaxGridNum = 65535; + +inline int GET_BLOCKS(const int N) +{ + return std::min(kMaxGridNum, (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS); +} + +template +__device__ scalar_t deformable_im2col_bilinear(const scalar_t *bottom_data, const int data_width, + const int height, const int width, scalar_t h, scalar_t w) +{ + + int h_low = floor(h); + int w_low = floor(w); + int h_high = h_low + 1; + int w_high = w_low + 1; + + scalar_t lh = h - h_low; + scalar_t lw = w - w_low; + scalar_t hh = 1 - lh, hw = 1 - lw; + + scalar_t v1 = 0; + if (h_low >= 0 && w_low >= 0) + v1 = bottom_data[h_low * data_width + w_low]; + scalar_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + v2 = bottom_data[h_low * data_width + w_high]; + scalar_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + v3 = bottom_data[h_high * data_width + w_low]; + scalar_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + v4 = bottom_data[h_high * data_width + w_high]; + + scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + + scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + return val; +} + +template +__device__ scalar_t get_gradient_weight(scalar_t argmax_h, scalar_t argmax_w, + const int h, const int w, const int height, const int width) +{ + + if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) + { + //empty + return 0; + } + + int argmax_h_low = floor(argmax_h); + int argmax_w_low = floor(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = argmax_w_low + 1; + + scalar_t weight = 0; + if (h == argmax_h_low && w == argmax_w_low) + weight = (h + 1 - argmax_h) * (w + 1 - argmax_w); + if (h == argmax_h_low && w == argmax_w_high) + weight = (h + 1 - argmax_h) * (argmax_w + 1 - w); + if (h == argmax_h_high && w == argmax_w_low) + weight = (argmax_h + 1 - h) * (w 
+ 1 - argmax_w); + if (h == argmax_h_high && w == argmax_w_high) + weight = (argmax_h + 1 - h) * (argmax_w + 1 - w); + return weight; +} + +template +__device__ scalar_t get_coordinate_weight(scalar_t argmax_h, scalar_t argmax_w, + const int height, const int width, const scalar_t *im_data, + const int data_width, const int bp_dir) +{ + + if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) + { + //empty + return 0; + } + + int argmax_h_low = floor(argmax_h); + int argmax_w_low = floor(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = argmax_w_low + 1; + + scalar_t weight = 0; + + if (bp_dir == 0) + { + if (argmax_h_low >= 0 && argmax_w_low >= 0) + weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low]; + if (argmax_h_low >= 0 && argmax_w_high <= width - 1) + weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high]; + if (argmax_h_high <= height - 1 && argmax_w_low >= 0) + weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low]; + if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high]; + } + else if (bp_dir == 1) + { + if (argmax_h_low >= 0 && argmax_w_low >= 0) + weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low]; + if (argmax_h_low >= 0 && argmax_w_high <= width - 1) + weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high]; + if (argmax_h_high <= height - 1 && argmax_w_low >= 0) + weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low]; + if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high]; + } + + return weight; +} + +template +__global__ void deformable_im2col_gpu_kernel(const int n, const scalar_t *data_im, const scalar_t *data_offset, + const int height, const int width, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int channel_per_deformable_group, + const int batch_size, const int num_channels, const int deformable_group, + const int height_col, const int width_col, + scalar_t *data_col) +{ + CUDA_KERNEL_LOOP(index, n) + { + // index index of output matrix + const int w_col = index % width_col; + const int h_col = (index / width_col) % height_col; + const int b_col = (index / width_col / height_col) % batch_size; + const int c_im = (index / width_col / height_col) / batch_size; + const int c_col = c_im * kernel_h * kernel_w; + + // compute deformable group index + const int deformable_group_index = c_im / channel_per_deformable_group; + + const int h_in = h_col * stride_h - pad_h; + const int w_in = w_col * stride_w - pad_w; + scalar_t *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col; + //const scalar_t* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in; + const scalar_t *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width; + const scalar_t *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; + + for (int i = 0; i < kernel_h; ++i) + { + for (int j = 0; j < kernel_w; ++j) + { + const int 
data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col; + const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col; + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + scalar_t val = static_cast(0); + const scalar_t h_im = h_in + i * dilation_h + offset_h; + const scalar_t w_im = w_in + j * dilation_w + offset_w; + if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) + { + //const scalar_t map_h = i * dilation_h + offset_h; + //const scalar_t map_w = j * dilation_w + offset_w; + //const int cur_height = height - h_in; + //const int cur_width = width - w_in; + //val = deformable_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w); + val = deformable_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im); + } + *data_col_ptr = val; + data_col_ptr += batch_size * height_col * width_col; + } + } + } +} + +void deformable_im2col( + const at::Tensor data_im, const at::Tensor data_offset, const int channels, + const int height, const int width, const int ksize_h, const int ksize_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int parallel_imgs, + const int deformable_group, at::Tensor data_col) +{ + // num_axes should be smaller than block size + // todo: check parallel_imgs is correctly passed in + int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1; + int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1; + int num_kernels = channels * height_col * width_col * parallel_imgs; + int channel_per_deformable_group = channels / deformable_group; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_im.scalar_type(), "deformable_im2col_gpu", ([&] { + const scalar_t *data_im_ = data_im.data_ptr(); + const scalar_t *data_offset_ = data_offset.data_ptr(); + scalar_t *data_col_ = data_col.data_ptr(); + + deformable_im2col_gpu_kernel<<>>( + num_kernels, data_im_, data_offset_, height, width, ksize_h, ksize_w, + pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + channel_per_deformable_group, parallel_imgs, channels, deformable_group, + height_col, width_col, data_col_); + })); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in deformable_im2col: %s\n", cudaGetErrorString(err)); + } +} + +template +__global__ void deformable_col2im_gpu_kernel( + const int n, const scalar_t *data_col, const scalar_t *data_offset, + const int channels, const int height, const int width, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int deformable_group, + const int height_col, const int width_col, + scalar_t *grad_im) +{ + CUDA_KERNEL_LOOP(index, n) + { + const int j = (index / width_col / height_col / batch_size) % kernel_w; + const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h; + const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h; + // compute the start and end of the output + + const int deformable_group_index = c / channel_per_deformable_group; + + int w_out = index % width_col; + int h_out = (index / width_col) % height_col; + int b = (index / width_col / height_col) % batch_size; + int w_in = 
w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + + const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * + 2 * kernel_h * kernel_w * height_col * width_col; + const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out; + const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out; + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h; + const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w; + + const scalar_t cur_top_grad = data_col[index]; + const int cur_h = (int)cur_inv_h_data; + const int cur_w = (int)cur_inv_w_data; + for (int dy = -2; dy <= 2; dy++) + { + for (int dx = -2; dx <= 2; dx++) + { + if (cur_h + dy >= 0 && cur_h + dy < height && + cur_w + dx >= 0 && cur_w + dx < width && + abs(cur_inv_h_data - (cur_h + dy)) < 1 && + abs(cur_inv_w_data - (cur_w + dx)) < 1) + { + int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx; + scalar_t weight = get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width); + atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad); + } + } + } + } +} + +void deformable_col2im( + const at::Tensor data_col, const at::Tensor data_offset, const int channels, + const int height, const int width, const int ksize_h, + const int ksize_w, const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int parallel_imgs, const int deformable_group, + at::Tensor grad_im) +{ + + // todo: make sure parallel_imgs is passed in correctly + int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1; + int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1; + int num_kernels = channels * ksize_h * ksize_w * height_col * width_col * parallel_imgs; + int channel_per_deformable_group = channels / deformable_group; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_col.scalar_type(), "deformable_col2im_gpu", ([&] { + const scalar_t *data_col_ = data_col.data_ptr(); + const scalar_t *data_offset_ = data_offset.data_ptr(); + scalar_t *grad_im_ = grad_im.data_ptr(); + + deformable_col2im_gpu_kernel<<>>( + num_kernels, data_col_, data_offset_, channels, height, width, ksize_h, + ksize_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, channel_per_deformable_group, + parallel_imgs, deformable_group, height_col, width_col, grad_im_); + })); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in deformable_col2im: %s\n", cudaGetErrorString(err)); + } +} + +template +__global__ void deformable_col2im_coord_gpu_kernel(const int n, const scalar_t *data_col, + const scalar_t *data_im, const scalar_t *data_offset, + const int channels, const int height, const int width, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int offset_channels, const int deformable_group, + const int height_col, const int width_col, scalar_t *grad_offset) +{ + CUDA_KERNEL_LOOP(index, n) + { + scalar_t val = 0; + int w = index % width_col; + int h = (index / width_col) % height_col; + int c = 
(index / width_col / height_col) % offset_channels; + int b = (index / width_col / height_col) / offset_channels; + // compute the start and end of the output + + const int deformable_group_index = c / (2 * kernel_h * kernel_w); + const int col_step = kernel_h * kernel_w; + int cnt = 0; + const scalar_t *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * + batch_size * width_col * height_col; + const scalar_t *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) * + channel_per_deformable_group / kernel_h / kernel_w * height * width; + const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * + kernel_h * kernel_w * height_col * width_col; + + const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w; + + for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step) + { + const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w; + const int bp_dir = offset_c % 2; + + int j = (col_pos / width_col / height_col / batch_size) % kernel_w; + int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h; + int w_out = col_pos % width_col; + int h_out = (col_pos / width_col) % height_col; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out); + const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out); + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + scalar_t inv_h = h_in + i * dilation_h + offset_h; + scalar_t inv_w = w_in + j * dilation_w + offset_w; + if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width) + { + inv_h = inv_w = -2; + } + const scalar_t weight = get_coordinate_weight( + inv_h, inv_w, + height, width, data_im_ptr + cnt * height * width, width, bp_dir); + val += weight * data_col_ptr[col_pos]; + cnt += 1; + } + + grad_offset[index] = val; + } +} + +void deformable_col2im_coord( + const at::Tensor data_col, const at::Tensor data_im, const at::Tensor data_offset, + const int channels, const int height, const int width, const int ksize_h, + const int ksize_w, const int pad_h, const int pad_w, const int stride_h, + const int stride_w, const int dilation_h, const int dilation_w, + const int parallel_imgs, const int deformable_group, at::Tensor grad_offset) +{ + + int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1; + int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1; + int num_kernels = height_col * width_col * 2 * ksize_h * ksize_w * deformable_group * parallel_imgs; + int channel_per_deformable_group = channels * ksize_h * ksize_w / deformable_group; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_col.scalar_type(), "deformable_col2im_coord_gpu", ([&] { + const scalar_t *data_col_ = data_col.data_ptr(); + const scalar_t *data_im_ = data_im.data_ptr(); + const scalar_t *data_offset_ = data_offset.data_ptr(); + scalar_t *grad_offset_ = grad_offset.data_ptr(); + + deformable_col2im_coord_gpu_kernel<<>>( + num_kernels, data_col_, data_im_, data_offset_, channels, height, width, + ksize_h, ksize_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, channel_per_deformable_group, + parallel_imgs, 2 * ksize_h * ksize_w * deformable_group, deformable_group, + height_col, 
width_col, grad_offset_); + })); +} + +template +__device__ scalar_t dmcn_im2col_bilinear(const scalar_t *bottom_data, const int data_width, + const int height, const int width, scalar_t h, scalar_t w) +{ + int h_low = floor(h); + int w_low = floor(w); + int h_high = h_low + 1; + int w_high = w_low + 1; + + scalar_t lh = h - h_low; + scalar_t lw = w - w_low; + scalar_t hh = 1 - lh, hw = 1 - lw; + + scalar_t v1 = 0; + if (h_low >= 0 && w_low >= 0) + v1 = bottom_data[h_low * data_width + w_low]; + scalar_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + v2 = bottom_data[h_low * data_width + w_high]; + scalar_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + v3 = bottom_data[h_high * data_width + w_low]; + scalar_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + v4 = bottom_data[h_high * data_width + w_high]; + + scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + + scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + return val; +} + +template +__device__ scalar_t dmcn_get_gradient_weight(scalar_t argmax_h, scalar_t argmax_w, + const int h, const int w, const int height, const int width) +{ + if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) + { + //empty + return 0; + } + + int argmax_h_low = floor(argmax_h); + int argmax_w_low = floor(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = argmax_w_low + 1; + + scalar_t weight = 0; + if (h == argmax_h_low && w == argmax_w_low) + weight = (h + 1 - argmax_h) * (w + 1 - argmax_w); + if (h == argmax_h_low && w == argmax_w_high) + weight = (h + 1 - argmax_h) * (argmax_w + 1 - w); + if (h == argmax_h_high && w == argmax_w_low) + weight = (argmax_h + 1 - h) * (w + 1 - argmax_w); + if (h == argmax_h_high && w == argmax_w_high) + weight = (argmax_h + 1 - h) * (argmax_w + 1 - w); + return weight; +} + +template +__device__ scalar_t dmcn_get_coordinate_weight(scalar_t argmax_h, scalar_t argmax_w, + const int height, const int width, const scalar_t *im_data, + const int data_width, const int bp_dir) +{ + if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) + { + //empty + return 0; + } + + int argmax_h_low = floor(argmax_h); + int argmax_w_low = floor(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = argmax_w_low + 1; + + scalar_t weight = 0; + + if (bp_dir == 0) + { + if (argmax_h_low >= 0 && argmax_w_low >= 0) + weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low]; + if (argmax_h_low >= 0 && argmax_w_high <= width - 1) + weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high]; + if (argmax_h_high <= height - 1 && argmax_w_low >= 0) + weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low]; + if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high]; + } + else if (bp_dir == 1) + { + if (argmax_h_low >= 0 && argmax_w_low >= 0) + weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low]; + if (argmax_h_low >= 0 && argmax_w_high <= width - 1) + weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high]; + if (argmax_h_high <= height - 1 && argmax_w_low >= 0) + weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low]; + if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + weight += 
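For reference, the bilinear sampling that `dmcn_im2col_bilinear` performs above can be written directly in Python. This is an illustrative sketch only (the function name and the single-channel tensor layout are assumptions for the example, not part of this patch); out-of-range corners contribute zero, matching the guards in the CUDA helper.

import math
import torch

def bilinear_sample(im: torch.Tensor, h: float, w: float) -> torch.Tensor:
    # Mirrors dmcn_im2col_bilinear for a single (H, W) map at fractional coords.
    height, width = im.shape
    h_low, w_low = math.floor(h), math.floor(w)
    h_high, w_high = h_low + 1, w_low + 1
    lh, lw = h - h_low, w - w_low
    hh, hw = 1 - lh, 1 - lw

    def at(y: int, x: int) -> torch.Tensor:
        # Corners outside the image contribute zero, as in the CUDA guards.
        if 0 <= y < height and 0 <= x < width:
            return im[y, x]
        return im.new_zeros(())

    return (hh * hw * at(h_low, w_low) + hh * lw * at(h_low, w_high)
            + lh * hw * at(h_high, w_low) + lh * lw * at(h_high, w_high))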
(argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high]; + } + + return weight; +} + +template +__global__ void modulated_deformable_im2col_gpu_kernel(const int n, + const scalar_t *data_im, const scalar_t *data_offset, const scalar_t *data_mask, + const int height, const int width, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int num_channels, const int deformable_group, + const int height_col, const int width_col, + scalar_t *data_col) +{ + CUDA_KERNEL_LOOP(index, n) + { + // index index of output matrix + const int w_col = index % width_col; + const int h_col = (index / width_col) % height_col; + const int b_col = (index / width_col / height_col) % batch_size; + const int c_im = (index / width_col / height_col) / batch_size; + const int c_col = c_im * kernel_h * kernel_w; + + // compute deformable group index + const int deformable_group_index = c_im / channel_per_deformable_group; + + const int h_in = h_col * stride_h - pad_h; + const int w_in = w_col * stride_w - pad_w; + + scalar_t *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col; + //const float* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in; + const scalar_t *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width; + const scalar_t *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; + + const scalar_t *data_mask_ptr = data_mask + (b_col * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; + + for (int i = 0; i < kernel_h; ++i) + { + for (int j = 0; j < kernel_w; ++j) + { + const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col; + const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col; + const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_col) * width_col + w_col; + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + const scalar_t mask = data_mask_ptr[data_mask_hw_ptr]; + scalar_t val = static_cast(0); + const scalar_t h_im = h_in + i * dilation_h + offset_h; + const scalar_t w_im = w_in + j * dilation_w + offset_w; + //if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) { + if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) + { + //const float map_h = i * dilation_h + offset_h; + //const float map_w = j * dilation_w + offset_w; + //const int cur_height = height - h_in; + //const int cur_width = width - w_in; + //val = dmcn_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w); + val = dmcn_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im); + } + *data_col_ptr = val * mask; + data_col_ptr += batch_size * height_col * width_col; + //data_col_ptr += height_col * width_col; + } + } + } +} + +template +__global__ void modulated_deformable_col2im_gpu_kernel(const int n, + const scalar_t *data_col, const scalar_t *data_offset, const scalar_t *data_mask, + const int channels, const int height, const int width, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + 
const int channel_per_deformable_group, + const int batch_size, const int deformable_group, + const int height_col, const int width_col, + scalar_t *grad_im) +{ + CUDA_KERNEL_LOOP(index, n) + { + const int j = (index / width_col / height_col / batch_size) % kernel_w; + const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h; + const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h; + // compute the start and end of the output + + const int deformable_group_index = c / channel_per_deformable_group; + + int w_out = index % width_col; + int h_out = (index / width_col) % height_col; + int b = (index / width_col / height_col) % batch_size; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + + const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; + const scalar_t *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; + const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out; + const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out; + const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_out) * width_col + w_out; + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + const scalar_t mask = data_mask_ptr[data_mask_hw_ptr]; + const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h; + const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w; + + const scalar_t cur_top_grad = data_col[index] * mask; + const int cur_h = (int)cur_inv_h_data; + const int cur_w = (int)cur_inv_w_data; + for (int dy = -2; dy <= 2; dy++) + { + for (int dx = -2; dx <= 2; dx++) + { + if (cur_h + dy >= 0 && cur_h + dy < height && + cur_w + dx >= 0 && cur_w + dx < width && + abs(cur_inv_h_data - (cur_h + dy)) < 1 && + abs(cur_inv_w_data - (cur_w + dx)) < 1) + { + int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx; + scalar_t weight = dmcn_get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width); + atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad); + } + } + } + } +} + +template +__global__ void modulated_deformable_col2im_coord_gpu_kernel(const int n, + const scalar_t *data_col, const scalar_t *data_im, + const scalar_t *data_offset, const scalar_t *data_mask, + const int channels, const int height, const int width, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int offset_channels, const int deformable_group, + const int height_col, const int width_col, + scalar_t *grad_offset, scalar_t *grad_mask) +{ + CUDA_KERNEL_LOOP(index, n) + { + scalar_t val = 0, mval = 0; + int w = index % width_col; + int h = (index / width_col) % height_col; + int c = (index / width_col / height_col) % offset_channels; + int b = (index / width_col / height_col) / offset_channels; + // compute the start and end of the output + + const int deformable_group_index = c / (2 * kernel_h * kernel_w); + const int col_step = kernel_h * kernel_w; + int cnt = 0; + const scalar_t *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * 
batch_size * width_col * height_col; + const scalar_t *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) * channel_per_deformable_group / kernel_h / kernel_w * height * width; + const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; + const scalar_t *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; + + const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w; + + for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step) + { + const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w; + const int bp_dir = offset_c % 2; + + int j = (col_pos / width_col / height_col / batch_size) % kernel_w; + int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h; + int w_out = col_pos % width_col; + int h_out = (col_pos / width_col) % height_col; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out); + const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out); + const int data_mask_hw_ptr = (((i * kernel_w + j) * height_col + h_out) * width_col + w_out); + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + const scalar_t mask = data_mask_ptr[data_mask_hw_ptr]; + scalar_t inv_h = h_in + i * dilation_h + offset_h; + scalar_t inv_w = w_in + j * dilation_w + offset_w; + if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width) + { + inv_h = inv_w = -2; + } + else + { + mval += data_col_ptr[col_pos] * dmcn_im2col_bilinear(data_im_ptr + cnt * height * width, width, height, width, inv_h, inv_w); + } + const scalar_t weight = dmcn_get_coordinate_weight( + inv_h, inv_w, + height, width, data_im_ptr + cnt * height * width, width, bp_dir); + val += weight * data_col_ptr[col_pos] * mask; + cnt += 1; + } + // KERNEL_ASSIGN(grad_offset[index], offset_req, val); + grad_offset[index] = val; + if (offset_c % 2 == 0) + // KERNEL_ASSIGN(grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w], mask_req, mval); + grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w] = mval; + } +} + +void modulated_deformable_im2col_cuda( + const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask, + const int batch_size, const int channels, const int height_im, const int width_im, + const int height_col, const int width_col, const int kernel_h, const int kenerl_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int deformable_group, at::Tensor data_col) +{ + // num_axes should be smaller than block size + const int channel_per_deformable_group = channels / deformable_group; + const int num_kernels = channels * batch_size * height_col * width_col; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_im.scalar_type(), "modulated_deformable_im2col_gpu", ([&] { + const scalar_t *data_im_ = data_im.data_ptr(); + const scalar_t *data_offset_ = data_offset.data_ptr(); + const scalar_t *data_mask_ = data_mask.data_ptr(); + scalar_t *data_col_ = 
data_col.data_ptr(); + + modulated_deformable_im2col_gpu_kernel<<>>( + num_kernels, data_im_, data_offset_, data_mask_, height_im, width_im, kernel_h, kenerl_w, + pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, channel_per_deformable_group, + batch_size, channels, deformable_group, height_col, width_col, data_col_); + })); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in modulated_deformable_im2col_cuda: %s\n", cudaGetErrorString(err)); + } +} + +void modulated_deformable_col2im_cuda( + const at::Tensor data_col, const at::Tensor data_offset, const at::Tensor data_mask, + const int batch_size, const int channels, const int height_im, const int width_im, + const int height_col, const int width_col, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int deformable_group, at::Tensor grad_im) +{ + + const int channel_per_deformable_group = channels / deformable_group; + const int num_kernels = channels * kernel_h * kernel_w * batch_size * height_col * width_col; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_col.scalar_type(), "modulated_deformable_col2im_gpu", ([&] { + const scalar_t *data_col_ = data_col.data_ptr(); + const scalar_t *data_offset_ = data_offset.data_ptr(); + const scalar_t *data_mask_ = data_mask.data_ptr(); + scalar_t *grad_im_ = grad_im.data_ptr(); + + modulated_deformable_col2im_gpu_kernel<<>>( + num_kernels, data_col_, data_offset_, data_mask_, channels, height_im, width_im, + kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, channel_per_deformable_group, + batch_size, deformable_group, height_col, width_col, grad_im_); + })); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in modulated_deformable_col2im_cuda: %s\n", cudaGetErrorString(err)); + } +} + +void modulated_deformable_col2im_coord_cuda( + const at::Tensor data_col, const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask, + const int batch_size, const int channels, const int height_im, const int width_im, + const int height_col, const int width_col, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int deformable_group, + at::Tensor grad_offset, at::Tensor grad_mask) +{ + const int num_kernels = batch_size * height_col * width_col * 2 * kernel_h * kernel_w * deformable_group; + const int channel_per_deformable_group = channels * kernel_h * kernel_w / deformable_group; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_col.scalar_type(), "modulated_deformable_col2im_coord_gpu", ([&] { + const scalar_t *data_col_ = data_col.data_ptr(); + const scalar_t *data_im_ = data_im.data_ptr(); + const scalar_t *data_offset_ = data_offset.data_ptr(); + const scalar_t *data_mask_ = data_mask.data_ptr(); + scalar_t *grad_offset_ = grad_offset.data_ptr(); + scalar_t *grad_mask_ = grad_mask.data_ptr(); + + modulated_deformable_col2im_coord_gpu_kernel<<>>( + num_kernels, data_col_, data_im_, data_offset_, data_mask_, channels, height_im, width_im, + kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, channel_per_deformable_group, + batch_size, 2 * kernel_h * kernel_w * deformable_group, deformable_group, height_col, width_col, + grad_offset_, grad_mask_); + })); + cudaError_t err = cudaGetLastError(); + if (err != 
cudaSuccess) + { + printf("error in modulated_deformable_col2im_coord_cuda: %s\n", cudaGetErrorString(err)); + } +} diff --git a/hordelib/nodes/facerestore_cf/basicsr/ops/dcn/src/deform_conv_ext.cpp b/hordelib/nodes/facerestore_cf/basicsr/ops/dcn/src/deform_conv_ext.cpp new file mode 100644 index 00000000..5c21d02c --- /dev/null +++ b/hordelib/nodes/facerestore_cf/basicsr/ops/dcn/src/deform_conv_ext.cpp @@ -0,0 +1,164 @@ +// modify from +// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c + +#include +#include + +#include +#include + +#define WITH_CUDA // always use cuda +#ifdef WITH_CUDA +int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight, + at::Tensor offset, at::Tensor output, + at::Tensor columns, at::Tensor ones, int kW, + int kH, int dW, int dH, int padW, int padH, + int dilationW, int dilationH, int group, + int deformable_group, int im2col_step); + +int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset, + at::Tensor gradOutput, at::Tensor gradInput, + at::Tensor gradOffset, at::Tensor weight, + at::Tensor columns, int kW, int kH, int dW, + int dH, int padW, int padH, int dilationW, + int dilationH, int group, + int deformable_group, int im2col_step); + +int deform_conv_backward_parameters_cuda( + at::Tensor input, at::Tensor offset, at::Tensor gradOutput, + at::Tensor gradWeight, // at::Tensor gradBias, + at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH, + int padW, int padH, int dilationW, int dilationH, int group, + int deformable_group, float scale, int im2col_step); + +void modulated_deform_conv_cuda_forward( + at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, + at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns, + int kernel_h, int kernel_w, const int stride_h, const int stride_w, + const int pad_h, const int pad_w, const int dilation_h, + const int dilation_w, const int group, const int deformable_group, + const bool with_bias); + +void modulated_deform_conv_cuda_backward( + at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, + at::Tensor offset, at::Tensor mask, at::Tensor columns, + at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias, + at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output, + int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h, + int pad_w, int dilation_h, int dilation_w, int group, int deformable_group, + const bool with_bias); +#endif + +int deform_conv_forward(at::Tensor input, at::Tensor weight, + at::Tensor offset, at::Tensor output, + at::Tensor columns, at::Tensor ones, int kW, + int kH, int dW, int dH, int padW, int padH, + int dilationW, int dilationH, int group, + int deformable_group, int im2col_step) { + if (input.device().is_cuda()) { +#ifdef WITH_CUDA + return deform_conv_forward_cuda(input, weight, offset, output, columns, + ones, kW, kH, dW, dH, padW, padH, dilationW, dilationH, group, + deformable_group, im2col_step); +#else + AT_ERROR("deform conv is not compiled with GPU support"); +#endif + } + AT_ERROR("deform conv is not implemented on CPU"); +} + +int deform_conv_backward_input(at::Tensor input, at::Tensor offset, + at::Tensor gradOutput, at::Tensor gradInput, + at::Tensor gradOffset, at::Tensor weight, + at::Tensor columns, int kW, int kH, int dW, + int dH, int padW, int padH, int dilationW, + int dilationH, int group, + int deformable_group, int im2col_step) { + if (input.device().is_cuda()) { +#ifdef 
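These launchers back the deformable-convolution extension declared and bound in deform_conv_ext.cpp below; on the Python side they are reached through the wrapper modules in basicsr/ops/dcn. The sketch below is hypothetical: it assumes the vendored dcn package re-exports ModulatedDeformConvPack with the same constructor arguments as upstream BasicSR, and it needs the compiled deform_conv_ext extension and a CUDA device.

import torch
from hordelib.nodes.facerestore_cf.basicsr.ops.dcn import ModulatedDeformConvPack  # assumed export

dcn = ModulatedDeformConvPack(64, 64, kernel_size=3, padding=1, deformable_groups=1).cuda()
x = torch.randn(1, 64, 32, 32, device="cuda")
y = dcn(x)  # offsets and masks are predicted internally; the kernels above do the sampling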
WITH_CUDA + return deform_conv_backward_input_cuda(input, offset, gradOutput, + gradInput, gradOffset, weight, columns, kW, kH, dW, dH, padW, padH, + dilationW, dilationH, group, deformable_group, im2col_step); +#else + AT_ERROR("deform conv is not compiled with GPU support"); +#endif + } + AT_ERROR("deform conv is not implemented on CPU"); +} + +int deform_conv_backward_parameters( + at::Tensor input, at::Tensor offset, at::Tensor gradOutput, + at::Tensor gradWeight, // at::Tensor gradBias, + at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH, + int padW, int padH, int dilationW, int dilationH, int group, + int deformable_group, float scale, int im2col_step) { + if (input.device().is_cuda()) { +#ifdef WITH_CUDA + return deform_conv_backward_parameters_cuda(input, offset, gradOutput, + gradWeight, columns, ones, kW, kH, dW, dH, padW, padH, dilationW, + dilationH, group, deformable_group, scale, im2col_step); +#else + AT_ERROR("deform conv is not compiled with GPU support"); +#endif + } + AT_ERROR("deform conv is not implemented on CPU"); +} + +void modulated_deform_conv_forward( + at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, + at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns, + int kernel_h, int kernel_w, const int stride_h, const int stride_w, + const int pad_h, const int pad_w, const int dilation_h, + const int dilation_w, const int group, const int deformable_group, + const bool with_bias) { + if (input.device().is_cuda()) { +#ifdef WITH_CUDA + return modulated_deform_conv_cuda_forward(input, weight, bias, ones, + offset, mask, output, columns, kernel_h, kernel_w, stride_h, + stride_w, pad_h, pad_w, dilation_h, dilation_w, group, + deformable_group, with_bias); +#else + AT_ERROR("modulated deform conv is not compiled with GPU support"); +#endif + } + AT_ERROR("modulated deform conv is not implemented on CPU"); +} + +void modulated_deform_conv_backward( + at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, + at::Tensor offset, at::Tensor mask, at::Tensor columns, + at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias, + at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output, + int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h, + int pad_w, int dilation_h, int dilation_w, int group, int deformable_group, + const bool with_bias) { + if (input.device().is_cuda()) { +#ifdef WITH_CUDA + return modulated_deform_conv_cuda_backward(input, weight, bias, ones, + offset, mask, columns, grad_input, grad_weight, grad_bias, grad_offset, + grad_mask, grad_output, kernel_h, kernel_w, stride_h, stride_w, + pad_h, pad_w, dilation_h, dilation_w, group, deformable_group, + with_bias); +#else + AT_ERROR("modulated deform conv is not compiled with GPU support"); +#endif + } + AT_ERROR("modulated deform conv is not implemented on CPU"); +} + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("deform_conv_forward", &deform_conv_forward, + "deform forward"); + m.def("deform_conv_backward_input", &deform_conv_backward_input, + "deform_conv_backward_input"); + m.def("deform_conv_backward_parameters", + &deform_conv_backward_parameters, + "deform_conv_backward_parameters"); + m.def("modulated_deform_conv_forward", + &modulated_deform_conv_forward, + "modulated deform conv forward"); + m.def("modulated_deform_conv_backward", + &modulated_deform_conv_backward, + "modulated deform conv backward"); +} diff --git a/hordelib/nodes/facerestore_cf/basicsr/ops/fused_act/__init__.py 
b/hordelib/nodes/facerestore_cf/basicsr/ops/fused_act/__init__.py new file mode 100644 index 00000000..81f5b1e8 --- /dev/null +++ b/hordelib/nodes/facerestore_cf/basicsr/ops/fused_act/__init__.py @@ -0,0 +1,3 @@ +from .fused_act import FusedLeakyReLU, fused_leaky_relu + +__all__ = ["FusedLeakyReLU", "fused_leaky_relu"] diff --git a/hordelib/nodes/facerestore_cf/basicsr/ops/fused_act/fused_act.py b/hordelib/nodes/facerestore_cf/basicsr/ops/fused_act/fused_act.py new file mode 100644 index 00000000..bb8be954 --- /dev/null +++ b/hordelib/nodes/facerestore_cf/basicsr/ops/fused_act/fused_act.py @@ -0,0 +1,98 @@ +# modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_act.py + +import torch +from torch import nn +from torch.autograd import Function + +try: + from . import fused_act_ext +except ImportError: + import os + + BASICSR_JIT = os.getenv("BASICSR_JIT") + if BASICSR_JIT == "True": + from torch.utils.cpp_extension import load + + module_path = os.path.dirname(__file__) + fused_act_ext = load( + "fused", + sources=[ + os.path.join(module_path, "src", "fused_bias_act.cpp"), + os.path.join(module_path, "src", "fused_bias_act_kernel.cu"), + ], + ) + + +class FusedLeakyReLUFunctionBackward(Function): + + @staticmethod + def forward(ctx, grad_output, out, negative_slope, scale): + ctx.save_for_backward(out) + ctx.negative_slope = negative_slope + ctx.scale = scale + + empty = grad_output.new_empty(0) + + grad_input = fused_act_ext.fused_bias_act(grad_output, empty, out, 3, 1, negative_slope, scale) + + dim = [0] + + if grad_input.ndim > 2: + dim += list(range(2, grad_input.ndim)) + + grad_bias = grad_input.sum(dim).detach() + + return grad_input, grad_bias + + @staticmethod + def backward(ctx, gradgrad_input, gradgrad_bias): + (out,) = ctx.saved_tensors + gradgrad_out = fused_act_ext.fused_bias_act( + gradgrad_input, + gradgrad_bias, + out, + 3, + 1, + ctx.negative_slope, + ctx.scale, + ) + + return gradgrad_out, None, None, None + + +class FusedLeakyReLUFunction(Function): + + @staticmethod + def forward(ctx, input, bias, negative_slope, scale): + empty = input.new_empty(0) + out = fused_act_ext.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) + ctx.save_for_backward(out) + ctx.negative_slope = negative_slope + ctx.scale = scale + + return out + + @staticmethod + def backward(ctx, grad_output): + (out,) = ctx.saved_tensors + + grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(grad_output, out, ctx.negative_slope, ctx.scale) + + return grad_input, grad_bias, None, None + + +class FusedLeakyReLU(nn.Module): + + def __init__(self, channel, negative_slope=0.2, scale=2**0.5): + super().__init__() + + self.bias = nn.Parameter(torch.zeros(channel)) + self.negative_slope = negative_slope + self.scale = scale + + def forward(self, input): + return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) + + +def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2**0.5): + return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) diff --git a/hordelib/nodes/facerestore_cf/basicsr/ops/fused_act/src/fused_bias_act.cpp b/hordelib/nodes/facerestore_cf/basicsr/ops/fused_act/src/fused_bias_act.cpp new file mode 100644 index 00000000..c6225bbc --- /dev/null +++ b/hordelib/nodes/facerestore_cf/basicsr/ops/fused_act/src/fused_bias_act.cpp @@ -0,0 +1,26 @@ +// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act.cpp +#include + + +torch::Tensor fused_bias_act_op(const torch::Tensor& input, + const 
b/hordelib/nodes/facerestore_cf/basicsr/ops/fused_act/__init__.py
new file mode 100644
index 00000000..81f5b1e8
--- /dev/null
+++ b/hordelib/nodes/facerestore_cf/basicsr/ops/fused_act/__init__.py
@@ -0,0 +1,3 @@
+from .fused_act import FusedLeakyReLU, fused_leaky_relu
+
+__all__ = ["FusedLeakyReLU", "fused_leaky_relu"]
diff --git a/hordelib/nodes/facerestore_cf/basicsr/ops/fused_act/fused_act.py b/hordelib/nodes/facerestore_cf/basicsr/ops/fused_act/fused_act.py
new file mode 100644
index 00000000..bb8be954
--- /dev/null
+++ b/hordelib/nodes/facerestore_cf/basicsr/ops/fused_act/fused_act.py
@@ -0,0 +1,98 @@
+# modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_act.py
+
+import torch
+from torch import nn
+from torch.autograd import Function
+
+try:
+    from . import fused_act_ext
+except ImportError:
+    import os
+
+    BASICSR_JIT = os.getenv("BASICSR_JIT")
+    if BASICSR_JIT == "True":
+        from torch.utils.cpp_extension import load
+
+        module_path = os.path.dirname(__file__)
+        fused_act_ext = load(
+            "fused",
+            sources=[
+                os.path.join(module_path, "src", "fused_bias_act.cpp"),
+                os.path.join(module_path, "src", "fused_bias_act_kernel.cu"),
+            ],
+        )
+
+
+class FusedLeakyReLUFunctionBackward(Function):
+
+    @staticmethod
+    def forward(ctx, grad_output, out, negative_slope, scale):
+        ctx.save_for_backward(out)
+        ctx.negative_slope = negative_slope
+        ctx.scale = scale
+
+        empty = grad_output.new_empty(0)
+
+        grad_input = fused_act_ext.fused_bias_act(grad_output, empty, out, 3, 1, negative_slope, scale)
+
+        dim = [0]
+
+        if grad_input.ndim > 2:
+            dim += list(range(2, grad_input.ndim))
+
+        grad_bias = grad_input.sum(dim).detach()
+
+        return grad_input, grad_bias
+
+    @staticmethod
+    def backward(ctx, gradgrad_input, gradgrad_bias):
+        (out,) = ctx.saved_tensors
+        gradgrad_out = fused_act_ext.fused_bias_act(
+            gradgrad_input,
+            gradgrad_bias,
+            out,
+            3,
+            1,
+            ctx.negative_slope,
+            ctx.scale,
+        )
+
+        return gradgrad_out, None, None, None
+
+
+class FusedLeakyReLUFunction(Function):
+
+    @staticmethod
+    def forward(ctx, input, bias, negative_slope, scale):
+        empty = input.new_empty(0)
+        out = fused_act_ext.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
+        ctx.save_for_backward(out)
+        ctx.negative_slope = negative_slope
+        ctx.scale = scale
+
+        return out
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        (out,) = ctx.saved_tensors
+
+        grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(grad_output, out, ctx.negative_slope, ctx.scale)
+
+        return grad_input, grad_bias, None, None
+
+
+class FusedLeakyReLU(nn.Module):
+
+    def __init__(self, channel, negative_slope=0.2, scale=2**0.5):
+        super().__init__()
+
+        self.bias = nn.Parameter(torch.zeros(channel))
+        self.negative_slope = negative_slope
+        self.scale = scale
+
+    def forward(self, input):
+        return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
+
+
+def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2**0.5):
+    return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
diff --git a/hordelib/nodes/facerestore_cf/basicsr/ops/fused_act/src/fused_bias_act.cpp b/hordelib/nodes/facerestore_cf/basicsr/ops/fused_act/src/fused_bias_act.cpp
new file mode 100644
index 00000000..c6225bbc
--- /dev/null
+++ b/hordelib/nodes/facerestore_cf/basicsr/ops/fused_act/src/fused_bias_act.cpp
@@ -0,0 +1,26 @@
+// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act.cpp
+#include <torch/extension.h>
+
+
+torch::Tensor fused_bias_act_op(const torch::Tensor& input,
+                                const
torch::Tensor& bias, + const torch::Tensor& refer, + int act, int grad, float alpha, float scale); + +#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) + +torch::Tensor fused_bias_act(const torch::Tensor& input, + const torch::Tensor& bias, + const torch::Tensor& refer, + int act, int grad, float alpha, float scale) { + CHECK_CUDA(input); + CHECK_CUDA(bias); + + return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); +} diff --git a/hordelib/nodes/facerestore_cf/basicsr/ops/fused_act/src/fused_bias_act_kernel.cu b/hordelib/nodes/facerestore_cf/basicsr/ops/fused_act/src/fused_bias_act_kernel.cu new file mode 100644 index 00000000..31a536f9 --- /dev/null +++ b/hordelib/nodes/facerestore_cf/basicsr/ops/fused_act/src/fused_bias_act_kernel.cu @@ -0,0 +1,100 @@ +// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act_kernel.cu +// Copyright (c) 2019, NVIDIA Corporation. All rights reserved. +// +// This work is made available under the Nvidia Source Code License-NC. +// To view a copy of this license, visit +// https://nvlabs.github.io/stylegan2/license.html + +#include + +#include +#include +#include +#include + +#include +#include + + +template +static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref, + int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) { + int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; + + scalar_t zero = 0.0; + + for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) { + scalar_t x = p_x[xi]; + + if (use_bias) { + x += p_b[(xi / step_b) % size_b]; + } + + scalar_t ref = use_ref ? p_ref[xi] : zero; + + scalar_t y; + + switch (act * 10 + grad) { + default: + case 10: y = x; break; + case 11: y = x; break; + case 12: y = 0.0; break; + + case 30: y = (x > 0.0) ? x : x * alpha; break; + case 31: y = (ref > 0.0) ? x : x * alpha; break; + case 32: y = 0.0; break; + } + + out[xi] = y * scale; + } +} + + +torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, + int act, int grad, float alpha, float scale) { + int curDevice = -1; + cudaGetDevice(&curDevice); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); + + auto x = input.contiguous(); + auto b = bias.contiguous(); + auto ref = refer.contiguous(); + + int use_bias = b.numel() ? 1 : 0; + int use_ref = ref.numel() ? 
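The Python entry points for this op are the FusedLeakyReLU module and the fused_leaky_relu function defined in fused_act.py above. A minimal usage sketch, assuming the fused_act_ext extension is available (prebuilt or JIT-compiled) and a CUDA device; the tensor shapes are illustrative only.

import torch
from hordelib.nodes.facerestore_cf.basicsr.ops.fused_act import FusedLeakyReLU, fused_leaky_relu

act = FusedLeakyReLU(channel=512).cuda()          # learnable per-channel bias + scaled leaky ReLU
x = torch.randn(4, 512, 16, 16, device="cuda")    # bias is applied along dim 1
y = act(x)                                        # equivalent to fused_leaky_relu(x, act.bias)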
1 : 0; + + int size_x = x.numel(); + int size_b = b.numel(); + int step_b = 1; + + for (int i = 1 + 1; i < x.dim(); i++) { + step_b *= x.size(i); + } + + int loop_x = 4; + int block_size = 4 * 32; + int grid_size = (size_x - 1) / (loop_x * block_size) + 1; + + auto y = torch::empty_like(x); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] { + fused_bias_act_kernel<<>>( + y.data_ptr(), + x.data_ptr(), + b.data_ptr(), + ref.data_ptr(), + act, + grad, + alpha, + scale, + loop_x, + size_x, + step_b, + size_b, + use_bias, + use_ref + ); + }); + + return y; +} diff --git a/hordelib/nodes/facerestore_cf/basicsr/ops/upfirdn2d/__init__.py b/hordelib/nodes/facerestore_cf/basicsr/ops/upfirdn2d/__init__.py new file mode 100644 index 00000000..c6fd35e4 --- /dev/null +++ b/hordelib/nodes/facerestore_cf/basicsr/ops/upfirdn2d/__init__.py @@ -0,0 +1,3 @@ +from .upfirdn2d import upfirdn2d + +__all__ = ["upfirdn2d"] diff --git a/hordelib/nodes/facerestore_cf/basicsr/ops/upfirdn2d/src/upfirdn2d.cpp b/hordelib/nodes/facerestore_cf/basicsr/ops/upfirdn2d/src/upfirdn2d.cpp new file mode 100644 index 00000000..12b56617 --- /dev/null +++ b/hordelib/nodes/facerestore_cf/basicsr/ops/upfirdn2d/src/upfirdn2d.cpp @@ -0,0 +1,24 @@ +// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.cpp +#include + + +torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, + int up_x, int up_y, int down_x, int down_y, + int pad_x0, int pad_x1, int pad_y0, int pad_y1); + +#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) + +torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel, + int up_x, int up_y, int down_x, int down_y, + int pad_x0, int pad_x1, int pad_y0, int pad_y1) { + CHECK_CUDA(input); + CHECK_CUDA(kernel); + + return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); +} diff --git a/hordelib/nodes/facerestore_cf/basicsr/ops/upfirdn2d/src/upfirdn2d_kernel.cu b/hordelib/nodes/facerestore_cf/basicsr/ops/upfirdn2d/src/upfirdn2d_kernel.cu new file mode 100644 index 00000000..e82913f5 --- /dev/null +++ b/hordelib/nodes/facerestore_cf/basicsr/ops/upfirdn2d/src/upfirdn2d_kernel.cu @@ -0,0 +1,370 @@ +// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d_kernel.cu +// Copyright (c) 2019, NVIDIA Corporation. All rights reserved. +// +// This work is made available under the Nvidia Source Code License-NC. 
+// To view a copy of this license, visit +// https://nvlabs.github.io/stylegan2/license.html + +#include + +#include +#include +#include +#include + +#include +#include + +static __host__ __device__ __forceinline__ int floor_div(int a, int b) { + int c = a / b; + + if (c * b > a) { + c--; + } + + return c; +} + +struct UpFirDn2DKernelParams { + int up_x; + int up_y; + int down_x; + int down_y; + int pad_x0; + int pad_x1; + int pad_y0; + int pad_y1; + + int major_dim; + int in_h; + int in_w; + int minor_dim; + int kernel_h; + int kernel_w; + int out_h; + int out_w; + int loop_major; + int loop_x; +}; + +template +__global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input, + const scalar_t *kernel, + const UpFirDn2DKernelParams p) { + int minor_idx = blockIdx.x * blockDim.x + threadIdx.x; + int out_y = minor_idx / p.minor_dim; + minor_idx -= out_y * p.minor_dim; + int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y; + int major_idx_base = blockIdx.z * p.loop_major; + + if (out_x_base >= p.out_w || out_y >= p.out_h || + major_idx_base >= p.major_dim) { + return; + } + + int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0; + int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h); + int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y; + int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y; + + for (int loop_major = 0, major_idx = major_idx_base; + loop_major < p.loop_major && major_idx < p.major_dim; + loop_major++, major_idx++) { + for (int loop_x = 0, out_x = out_x_base; + loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) { + int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0; + int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w); + int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x; + int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x; + + const scalar_t *x_p = + &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + + minor_idx]; + const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x]; + int x_px = p.minor_dim; + int k_px = -p.up_x; + int x_py = p.in_w * p.minor_dim; + int k_py = -p.up_y * p.kernel_w; + + scalar_t v = 0.0f; + + for (int y = 0; y < h; y++) { + for (int x = 0; x < w; x++) { + v += static_cast(*x_p) * static_cast(*k_p); + x_p += x_px; + k_p += k_px; + } + + x_p += x_py - w * x_px; + k_p += k_py - w * k_px; + } + + out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + + minor_idx] = v; + } + } +} + +template +__global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input, + const scalar_t *kernel, + const UpFirDn2DKernelParams p) { + const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1; + const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1; + + __shared__ volatile float sk[kernel_h][kernel_w]; + __shared__ volatile float sx[tile_in_h][tile_in_w]; + + int minor_idx = blockIdx.x; + int tile_out_y = minor_idx / p.minor_dim; + minor_idx -= tile_out_y * p.minor_dim; + tile_out_y *= tile_out_h; + int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w; + int major_idx_base = blockIdx.z * p.loop_major; + + if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | + major_idx_base >= p.major_dim) { + return; + } + + for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; + tap_idx += blockDim.x) { + int ky = tap_idx / kernel_w; + int kx = tap_idx - ky * kernel_w; + scalar_t v = 0.0; + + if (kx < p.kernel_w & ky < p.kernel_h) { + v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w 
+ (p.kernel_w - 1 - kx)]; + } + + sk[ky][kx] = v; + } + + for (int loop_major = 0, major_idx = major_idx_base; + loop_major < p.loop_major & major_idx < p.major_dim; + loop_major++, major_idx++) { + for (int loop_x = 0, tile_out_x = tile_out_x_base; + loop_x < p.loop_x & tile_out_x < p.out_w; + loop_x++, tile_out_x += tile_out_w) { + int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0; + int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0; + int tile_in_x = floor_div(tile_mid_x, up_x); + int tile_in_y = floor_div(tile_mid_y, up_y); + + __syncthreads(); + + for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; + in_idx += blockDim.x) { + int rel_in_y = in_idx / tile_in_w; + int rel_in_x = in_idx - rel_in_y * tile_in_w; + int in_x = rel_in_x + tile_in_x; + int in_y = rel_in_y + tile_in_y; + + scalar_t v = 0.0; + + if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) { + v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * + p.minor_dim + + minor_idx]; + } + + sx[rel_in_y][rel_in_x] = v; + } + + __syncthreads(); + for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; + out_idx += blockDim.x) { + int rel_out_y = out_idx / tile_out_w; + int rel_out_x = out_idx - rel_out_y * tile_out_w; + int out_x = rel_out_x + tile_out_x; + int out_y = rel_out_y + tile_out_y; + + int mid_x = tile_mid_x + rel_out_x * down_x; + int mid_y = tile_mid_y + rel_out_y * down_y; + int in_x = floor_div(mid_x, up_x); + int in_y = floor_div(mid_y, up_y); + int rel_in_x = in_x - tile_in_x; + int rel_in_y = in_y - tile_in_y; + int kernel_x = (in_x + 1) * up_x - mid_x - 1; + int kernel_y = (in_y + 1) * up_y - mid_y - 1; + + scalar_t v = 0.0; + +#pragma unroll + for (int y = 0; y < kernel_h / up_y; y++) +#pragma unroll + for (int x = 0; x < kernel_w / up_x; x++) + v += sx[rel_in_y + y][rel_in_x + x] * + sk[kernel_y + y * up_y][kernel_x + x * up_x]; + + if (out_x < p.out_w & out_y < p.out_h) { + out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + + minor_idx] = v; + } + } + } + } +} + +torch::Tensor upfirdn2d_op(const torch::Tensor &input, + const torch::Tensor &kernel, int up_x, int up_y, + int down_x, int down_y, int pad_x0, int pad_x1, + int pad_y0, int pad_y1) { + int curDevice = -1; + cudaGetDevice(&curDevice); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); + + UpFirDn2DKernelParams p; + + auto x = input.contiguous(); + auto k = kernel.contiguous(); + + p.major_dim = x.size(0); + p.in_h = x.size(1); + p.in_w = x.size(2); + p.minor_dim = x.size(3); + p.kernel_h = k.size(0); + p.kernel_w = k.size(1); + p.up_x = up_x; + p.up_y = up_y; + p.down_x = down_x; + p.down_y = down_y; + p.pad_x0 = pad_x0; + p.pad_x1 = pad_x1; + p.pad_y0 = pad_y0; + p.pad_y1 = pad_y1; + + p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) / + p.down_y; + p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) / + p.down_x; + + auto out = + at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options()); + + int mode = -1; + + int tile_out_h = -1; + int tile_out_w = -1; + + if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && + p.kernel_h <= 4 && p.kernel_w <= 4) { + mode = 1; + tile_out_h = 16; + tile_out_w = 64; + } + + if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && + p.kernel_h <= 3 && p.kernel_w <= 3) { + mode = 2; + tile_out_h = 16; + tile_out_w = 64; + } + + if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && + p.kernel_h <= 4 && p.kernel_w <= 4) { + mode = 3; + 
tile_out_h = 16; + tile_out_w = 64; + } + + if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && + p.kernel_h <= 2 && p.kernel_w <= 2) { + mode = 4; + tile_out_h = 16; + tile_out_w = 64; + } + + if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && + p.kernel_h <= 4 && p.kernel_w <= 4) { + mode = 5; + tile_out_h = 8; + tile_out_w = 32; + } + + if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && + p.kernel_h <= 2 && p.kernel_w <= 2) { + mode = 6; + tile_out_h = 8; + tile_out_w = 32; + } + + dim3 block_size; + dim3 grid_size; + + if (tile_out_h > 0 && tile_out_w > 0) { + p.loop_major = (p.major_dim - 1) / 16384 + 1; + p.loop_x = 1; + block_size = dim3(32 * 8, 1, 1); + grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim, + (p.out_w - 1) / (p.loop_x * tile_out_w) + 1, + (p.major_dim - 1) / p.loop_major + 1); + } else { + p.loop_major = (p.major_dim - 1) / 16384 + 1; + p.loop_x = 4; + block_size = dim3(4, 32, 1); + grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1, + (p.out_w - 1) / (p.loop_x * block_size.y) + 1, + (p.major_dim - 1) / p.loop_major + 1); + } + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] { + switch (mode) { + case 1: + upfirdn2d_kernel + <<>>(out.data_ptr(), + x.data_ptr(), + k.data_ptr(), p); + + break; + + case 2: + upfirdn2d_kernel + <<>>(out.data_ptr(), + x.data_ptr(), + k.data_ptr(), p); + + break; + + case 3: + upfirdn2d_kernel + <<>>(out.data_ptr(), + x.data_ptr(), + k.data_ptr(), p); + + break; + + case 4: + upfirdn2d_kernel + <<>>(out.data_ptr(), + x.data_ptr(), + k.data_ptr(), p); + + break; + + case 5: + upfirdn2d_kernel + <<>>(out.data_ptr(), + x.data_ptr(), + k.data_ptr(), p); + + break; + + case 6: + upfirdn2d_kernel + <<>>(out.data_ptr(), + x.data_ptr(), + k.data_ptr(), p); + + break; + + default: + upfirdn2d_kernel_large<<>>( + out.data_ptr(), x.data_ptr(), + k.data_ptr(), p); + } + }); + + return out; +} diff --git a/hordelib/nodes/facerestore_cf/basicsr/ops/upfirdn2d/upfirdn2d.py b/hordelib/nodes/facerestore_cf/basicsr/ops/upfirdn2d/upfirdn2d.py new file mode 100644 index 00000000..59a9411e --- /dev/null +++ b/hordelib/nodes/facerestore_cf/basicsr/ops/upfirdn2d/upfirdn2d.py @@ -0,0 +1,188 @@ +# modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.py + +import torch +from torch.autograd import Function +from torch.nn import functional as F + +try: + from . 
import upfirdn2d_ext +except ImportError: + import os + + BASICSR_JIT = os.getenv("BASICSR_JIT") + if BASICSR_JIT == "True": + from torch.utils.cpp_extension import load + + module_path = os.path.dirname(__file__) + upfirdn2d_ext = load( + "upfirdn2d", + sources=[ + os.path.join(module_path, "src", "upfirdn2d.cpp"), + os.path.join(module_path, "src", "upfirdn2d_kernel.cu"), + ], + ) + + +class UpFirDn2dBackward(Function): + + @staticmethod + def forward(ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size): + + up_x, up_y = up + down_x, down_y = down + g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad + + grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) + + grad_input = upfirdn2d_ext.upfirdn2d( + grad_output, + grad_kernel, + down_x, + down_y, + up_x, + up_y, + g_pad_x0, + g_pad_x1, + g_pad_y0, + g_pad_y1, + ) + grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) + + ctx.save_for_backward(kernel) + + pad_x0, pad_x1, pad_y0, pad_y1 = pad + + ctx.up_x = up_x + ctx.up_y = up_y + ctx.down_x = down_x + ctx.down_y = down_y + ctx.pad_x0 = pad_x0 + ctx.pad_x1 = pad_x1 + ctx.pad_y0 = pad_y0 + ctx.pad_y1 = pad_y1 + ctx.in_size = in_size + ctx.out_size = out_size + + return grad_input + + @staticmethod + def backward(ctx, gradgrad_input): + (kernel,) = ctx.saved_tensors + + gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1) + + gradgrad_out = upfirdn2d_ext.upfirdn2d( + gradgrad_input, + kernel, + ctx.up_x, + ctx.up_y, + ctx.down_x, + ctx.down_y, + ctx.pad_x0, + ctx.pad_x1, + ctx.pad_y0, + ctx.pad_y1, + ) + # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], + # ctx.out_size[1], ctx.in_size[3]) + gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]) + + return gradgrad_out, None, None, None, None, None, None, None, None + + +class UpFirDn2d(Function): + + @staticmethod + def forward(ctx, input, kernel, up, down, pad): + up_x, up_y = up + down_x, down_y = down + pad_x0, pad_x1, pad_y0, pad_y1 = pad + + kernel_h, kernel_w = kernel.shape + batch, channel, in_h, in_w = input.shape + ctx.in_size = input.shape + + input = input.reshape(-1, in_h, in_w, 1) + + ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) + + out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 + out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 + ctx.out_size = (out_h, out_w) + + ctx.up = (up_x, up_y) + ctx.down = (down_x, down_y) + ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) + + g_pad_x0 = kernel_w - pad_x0 - 1 + g_pad_y0 = kernel_h - pad_y0 - 1 + g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 + g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 + + ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) + + out = upfirdn2d_ext.upfirdn2d(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1) + # out = out.view(major, out_h, out_w, minor) + out = out.view(-1, channel, out_h, out_w) + + return out + + @staticmethod + def backward(ctx, grad_output): + kernel, grad_kernel = ctx.saved_tensors + + grad_input = UpFirDn2dBackward.apply( + grad_output, + kernel, + grad_kernel, + ctx.up, + ctx.down, + ctx.pad, + ctx.g_pad, + ctx.in_size, + ctx.out_size, + ) + + return grad_input, None, None, None, None + + +def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): + if input.device.type == "cpu": + out = upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]) + else: + out = UpFirDn2d.apply(input, 
kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])) + + return out + + +def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1): + _, channel, in_h, in_w = input.shape + input = input.reshape(-1, in_h, in_w, 1) + + _, in_h, in_w, minor = input.shape + kernel_h, kernel_w = kernel.shape + + out = input.view(-1, in_h, 1, in_w, 1, minor) + out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) + out = out.view(-1, in_h * up_y, in_w * up_x, minor) + + out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]) + out = out[:, max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), :] + + out = out.permute(0, 3, 1, 2) + out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]) + w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) + out = F.conv2d(out, w) + out = out.reshape( + -1, + minor, + in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, + in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, + ) + out = out.permute(0, 2, 3, 1) + out = out[:, ::down_y, ::down_x, :] + + out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 + out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 + + return out.view(-1, channel, out_h, out_w) diff --git a/hordelib/nodes/facerestore_cf/basicsr/setup.py b/hordelib/nodes/facerestore_cf/basicsr/setup.py new file mode 100644 index 00000000..b8755f97 --- /dev/null +++ b/hordelib/nodes/facerestore_cf/basicsr/setup.py @@ -0,0 +1,171 @@ +#!/usr/bin/env python + +import os +import subprocess +import sys +import time + +import torch +from setuptools import find_packages, setup +from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension + +version_file = "./basicsr/version.py" + + +def readme(): + with open("README.md", encoding="utf-8") as f: + content = f.read() + return content + + +def get_git_hash(): + + def _minimal_ext_cmd(cmd): + # construct minimal environment + env = {} + for k in ["SYSTEMROOT", "PATH", "HOME"]: + v = os.environ.get(k) + if v is not None: + env[k] = v + # LANGUAGE is used on win32 + env["LANGUAGE"] = "C" + env["LANG"] = "C" + env["LC_ALL"] = "C" + out = subprocess.Popen(cmd, stdout=subprocess.PIPE, env=env).communicate()[0] + return out + + try: + out = _minimal_ext_cmd(["git", "rev-parse", "HEAD"]) + sha = out.strip().decode("ascii") + except OSError: + sha = "unknown" + + return sha + + +def get_hash(): + if os.path.exists(".git"): + sha = get_git_hash()[:7] + elif os.path.exists(version_file): + try: + from version import __version__ + + sha = __version__.split("+")[-1] + except ImportError: + raise ImportError("Unable to get git version") + else: + sha = "unknown" + + return sha + + +def write_version_py(): + content = """# GENERATED VERSION FILE +# TIME: {} +__version__ = '{}' +__gitsha__ = '{}' +version_info = ({}) +""" + sha = get_hash() + with open("./basicsr/VERSION") as f: + SHORT_VERSION = f.read().strip() + VERSION_INFO = ", ".join([x if x.isdigit() else f'"{x}"' for x in SHORT_VERSION.split(".")]) + + version_file_str = content.format(time.asctime(), SHORT_VERSION, sha, VERSION_INFO) + with open(version_file, "w") as f: + f.write(version_file_str) + + +def get_version(): + with open(version_file) as f: + exec(compile(f.read(), version_file, "exec")) + return locals()["__version__"] + + +def make_cuda_ext(name, module, sources, sources_cuda=None): + if sources_cuda is None: + sources_cuda = [] + define_macros = [] + extra_compile_args = 
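upfirdn2d.py above exposes a single upfirdn2d(input, kernel, up, down, pad) helper (upsample, FIR filter, downsample). On CPU it falls back to the pure-PyTorch upfirdn2d_native path, so the sketch below runs without the compiled extension; the blur-kernel values and shapes are illustrative only.

import torch
from hordelib.nodes.facerestore_cf.basicsr.ops.upfirdn2d import upfirdn2d

k = torch.tensor([1.0, 3.0, 3.0, 1.0])
kernel = torch.outer(k, k)
kernel = kernel / kernel.sum()                             # normalised 4x4 blur kernel

x = torch.randn(1, 3, 64, 64)                              # NCHW, CPU tensor -> upfirdn2d_native
up2 = upfirdn2d(x, kernel * 4, up=2, down=1, pad=(1, 2))   # 2x upsample + blur -> (1, 3, 128, 128)
down2 = upfirdn2d(x, kernel, up=1, down=2, pad=(1, 1))     # blur + 2x downsample -> (1, 3, 32, 32)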
{"cxx": []} + + if torch.cuda.is_available() or os.getenv("FORCE_CUDA", "0") == "1": + define_macros += [("WITH_CUDA", None)] + extension = CUDAExtension + extra_compile_args["nvcc"] = [ + "-D__CUDA_NO_HALF_OPERATORS__", + "-D__CUDA_NO_HALF_CONVERSIONS__", + "-D__CUDA_NO_HALF2_OPERATORS__", + ] + sources += sources_cuda + else: + print(f"Compiling {name} without CUDA") + extension = CppExtension + + return extension( + name=f"{module}.{name}", + sources=[os.path.join(*module.split("."), p) for p in sources], + define_macros=define_macros, + extra_compile_args=extra_compile_args, + ) + + +def get_requirements(filename="requirements.txt"): + with open(os.path.join(".", filename)) as f: + requires = [line.replace("\n", "") for line in f.readlines()] + return requires + + +if __name__ == "__main__": + if "--cuda_ext" in sys.argv: + ext_modules = [ + make_cuda_ext( + name="deform_conv_ext", + module="ops.dcn", + sources=["src/deform_conv_ext.cpp"], + sources_cuda=["src/deform_conv_cuda.cpp", "src/deform_conv_cuda_kernel.cu"], + ), + make_cuda_ext( + name="fused_act_ext", + module="ops.fused_act", + sources=["src/fused_bias_act.cpp"], + sources_cuda=["src/fused_bias_act_kernel.cu"], + ), + make_cuda_ext( + name="upfirdn2d_ext", + module="ops.upfirdn2d", + sources=["src/upfirdn2d.cpp"], + sources_cuda=["src/upfirdn2d_kernel.cu"], + ), + ] + sys.argv.remove("--cuda_ext") + else: + ext_modules = [] + + write_version_py() + setup( + name="basicsr", + version=get_version(), + description="Open Source Image and Video Super-Resolution Toolbox", + long_description=readme(), + long_description_content_type="text/markdown", + author="Xintao Wang", + author_email="xintao.wang@outlook.com", + keywords="computer vision, restoration, super resolution", + url="https://github.com/xinntao/BasicSR", + include_package_data=True, + packages=find_packages(exclude=("options", "datasets", "experiments", "results", "tb_logger", "wandb")), + classifiers=[ + "Development Status :: 4 - Beta", + "License :: OSI Approved :: Apache Software License", + "Operating System :: OS Independent", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + ], + license="Apache License 2.0", + setup_requires=["cython", "numpy"], + install_requires=get_requirements(), + ext_modules=ext_modules, + cmdclass={"build_ext": BuildExtension}, + zip_safe=False, + ) diff --git a/hordelib/nodes/facerestore_cf/basicsr/train.py b/hordelib/nodes/facerestore_cf/basicsr/train.py new file mode 100644 index 00000000..7e84af31 --- /dev/null +++ b/hordelib/nodes/facerestore_cf/basicsr/train.py @@ -0,0 +1,250 @@ +import argparse +import datetime +import logging +import math +import random +import time +import warnings +from os import path as osp + +import torch + +from hordelib.nodes.facerestore_cf.basicsr.data import build_dataloader, build_dataset +from hordelib.nodes.facerestore_cf.basicsr.data.data_sampler import EnlargedSampler +from hordelib.nodes.facerestore_cf.basicsr.data.prefetch_dataloader import CPUPrefetcher, CUDAPrefetcher +from hordelib.nodes.facerestore_cf.basicsr.models import build_model +from hordelib.nodes.facerestore_cf.basicsr.utils import ( + MessageLogger, + check_resume, + get_env_info, + get_root_logger, + init_tb_logger, + init_wandb_logger, + make_exp_dirs, + mkdir_and_rename, + set_random_seed, +) +from hordelib.nodes.facerestore_cf.basicsr.utils.dist_util import get_dist_info, init_dist +from hordelib.nodes.facerestore_cf.basicsr.utils.options import 
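The try/except blocks in fused_act.py and upfirdn2d.py above only JIT-compile the CUDA sources when the BASICSR_JIT environment variable is set; the --cuda_ext flag handled by setup.py above is the prebuilt alternative. A sketch of opting in to the JIT path, assuming a working CUDA toolchain is installed:

import os

os.environ["BASICSR_JIT"] = "True"  # must be set before the ops are first imported

from hordelib.nodes.facerestore_cf.basicsr.ops.fused_act import fused_leaky_relu  # noqa: E402
from hordelib.nodes.facerestore_cf.basicsr.ops.upfirdn2d import upfirdn2d  # noqa: E402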
dict2str, parse + +# ignore UserWarning: Detected call of `lr_scheduler.step()` before `optimizer.step()`. +warnings.filterwarnings("ignore", category=UserWarning) + + +def parse_options(root_path, is_train=True): + parser = argparse.ArgumentParser() + parser.add_argument("-opt", type=str, required=True, help="Path to option YAML file.") + parser.add_argument("--launcher", choices=["none", "pytorch", "slurm"], default="none", help="job launcher") + parser.add_argument("--local_rank", type=int, default=0) + args = parser.parse_args() + opt = parse(args.opt, root_path, is_train=is_train) + + # distributed settings + if args.launcher == "none": + opt["dist"] = False + print("Disable distributed.", flush=True) + else: + opt["dist"] = True + if args.launcher == "slurm" and "dist_params" in opt: + init_dist(args.launcher, **opt["dist_params"]) + else: + init_dist(args.launcher) + + opt["rank"], opt["world_size"] = get_dist_info() + + # random seed + seed = opt.get("manual_seed") + if seed is None: + seed = random.randint(1, 10000) + opt["manual_seed"] = seed + set_random_seed(seed + opt["rank"]) + + return opt + + +def init_loggers(opt): + log_file = osp.join(opt["path"]["log"], f"train_{opt['name']}.log") + logger = get_root_logger(logger_name="basicsr", log_level=logging.INFO, log_file=log_file) + logger.info(get_env_info()) + logger.info(dict2str(opt)) + + # initialize wandb logger before tensorboard logger to allow proper sync: + if (opt["logger"].get("wandb") is not None) and (opt["logger"]["wandb"].get("project") is not None): + assert opt["logger"].get("use_tb_logger") is True, "should turn on tensorboard when using wandb" + init_wandb_logger(opt) + tb_logger = None + if opt["logger"].get("use_tb_logger"): + tb_logger = init_tb_logger(log_dir=osp.join("tb_logger", opt["name"])) + return logger, tb_logger + + +def create_train_val_dataloader(opt, logger): + # create train and val dataloaders + train_loader, val_loader = None, None + for phase, dataset_opt in opt["datasets"].items(): + if phase == "train": + dataset_enlarge_ratio = dataset_opt.get("dataset_enlarge_ratio", 1) + train_set = build_dataset(dataset_opt) + train_sampler = EnlargedSampler(train_set, opt["world_size"], opt["rank"], dataset_enlarge_ratio) + train_loader = build_dataloader( + train_set, + dataset_opt, + num_gpu=opt["num_gpu"], + dist=opt["dist"], + sampler=train_sampler, + seed=opt["manual_seed"], + ) + + num_iter_per_epoch = math.ceil( + len(train_set) * dataset_enlarge_ratio / (dataset_opt["batch_size_per_gpu"] * opt["world_size"]), + ) + total_iters = int(opt["train"]["total_iter"]) + total_epochs = math.ceil(total_iters / (num_iter_per_epoch)) + logger.info( + "Training statistics:" + f"\n\tNumber of train images: {len(train_set)}" + f"\n\tDataset enlarge ratio: {dataset_enlarge_ratio}" + f'\n\tBatch size per gpu: {dataset_opt["batch_size_per_gpu"]}' + f'\n\tWorld size (gpu number): {opt["world_size"]}' + f"\n\tRequire iter number per epoch: {num_iter_per_epoch}" + f"\n\tTotal epochs: {total_epochs}; iters: {total_iters}.", + ) + + elif phase == "val": + val_set = build_dataset(dataset_opt) + val_loader = build_dataloader( + val_set, + dataset_opt, + num_gpu=opt["num_gpu"], + dist=opt["dist"], + sampler=None, + seed=opt["manual_seed"], + ) + logger.info(f'Number of val images/folders in {dataset_opt["name"]}: ' f"{len(val_set)}") + else: + raise ValueError(f"Dataset phase {phase} is not recognized.") + + return train_loader, train_sampler, val_loader, total_epochs, total_iters + + +def train_pipeline(root_path): 
+ # parse options, set distributed setting, set ramdom seed + opt = parse_options(root_path, is_train=True) + + torch.backends.cudnn.benchmark = True + # torch.backends.cudnn.deterministic = True + + # load resume states if necessary + if opt["path"].get("resume_state"): + device_id = torch.cuda.current_device() + resume_state = torch.load( + opt["path"]["resume_state"], + map_location=lambda storage, loc: storage.cuda(device_id), + ) + else: + resume_state = None + + # mkdir for experiments and logger + if resume_state is None: + make_exp_dirs(opt) + if opt["logger"].get("use_tb_logger") and opt["rank"] == 0: + mkdir_and_rename(osp.join("tb_logger", opt["name"])) + + # initialize loggers + logger, tb_logger = init_loggers(opt) + + # create train and validation dataloaders + result = create_train_val_dataloader(opt, logger) + train_loader, train_sampler, val_loader, total_epochs, total_iters = result + + # create model + if resume_state: # resume training + check_resume(opt, resume_state["iter"]) + model = build_model(opt) + model.resume_training(resume_state) # handle optimizers and schedulers + logger.info(f"Resuming training from epoch: {resume_state['epoch']}, " f"iter: {resume_state['iter']}.") + start_epoch = resume_state["epoch"] + current_iter = resume_state["iter"] + else: + model = build_model(opt) + start_epoch = 0 + current_iter = 0 + + # create message logger (formatted outputs) + msg_logger = MessageLogger(opt, current_iter, tb_logger) + + # dataloader prefetcher + prefetch_mode = opt["datasets"]["train"].get("prefetch_mode") + if prefetch_mode is None or prefetch_mode == "cpu": + prefetcher = CPUPrefetcher(train_loader) + elif prefetch_mode == "cuda": + prefetcher = CUDAPrefetcher(train_loader, opt) + logger.info(f"Use {prefetch_mode} prefetch dataloader") + if opt["datasets"]["train"].get("pin_memory") is not True: + raise ValueError("Please set pin_memory=True for CUDAPrefetcher.") + else: + raise ValueError(f"Wrong prefetch_mode {prefetch_mode}." 
"Supported ones are: None, 'cuda', 'cpu'.") + + # training + logger.info(f"Start training from epoch: {start_epoch}, iter: {current_iter+1}") + data_time, iter_time = time.time(), time.time() + start_time = time.time() + + for epoch in range(start_epoch, total_epochs + 1): + train_sampler.set_epoch(epoch) + prefetcher.reset() + train_data = prefetcher.next() + + while train_data is not None: + data_time = time.time() - data_time + + current_iter += 1 + if current_iter > total_iters: + break + # update learning rate + model.update_learning_rate(current_iter, warmup_iter=opt["train"].get("warmup_iter", -1)) + # training + model.feed_data(train_data) + model.optimize_parameters(current_iter) + iter_time = time.time() - iter_time + # log + if current_iter % opt["logger"]["print_freq"] == 0: + log_vars = {"epoch": epoch, "iter": current_iter} + log_vars.update({"lrs": model.get_current_learning_rate()}) + log_vars.update({"time": iter_time, "data_time": data_time}) + log_vars.update(model.get_current_log()) + msg_logger(log_vars) + + # save models and training states + if current_iter % opt["logger"]["save_checkpoint_freq"] == 0: + logger.info("Saving models and training states.") + model.save(epoch, current_iter) + + # validation + if ( + opt.get("val") is not None + and opt["datasets"].get("val") is not None + and (current_iter % opt["val"]["val_freq"] == 0) + ): + model.validation(val_loader, current_iter, tb_logger, opt["val"]["save_img"]) + + data_time = time.time() + iter_time = time.time() + train_data = prefetcher.next() + # end of iter + + # end of epoch + + consumed_time = str(datetime.timedelta(seconds=int(time.time() - start_time))) + logger.info(f"End of training. Time consumed: {consumed_time}") + logger.info("Save the latest model.") + model.save(epoch=-1, current_iter=-1) # -1 stands for the latest + if opt.get("val") is not None and opt["datasets"].get("val"): + model.validation(val_loader, current_iter, tb_logger, opt["val"]["save_img"]) + if tb_logger: + tb_logger.close() + + +if __name__ == "__main__": + root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir)) + train_pipeline(root_path) diff --git a/hordelib/nodes/facerestore_cf/basicsr/utils/__init__.py b/hordelib/nodes/facerestore_cf/basicsr/utils/__init__.py new file mode 100644 index 00000000..b1d60bc8 --- /dev/null +++ b/hordelib/nodes/facerestore_cf/basicsr/utils/__init__.py @@ -0,0 +1,29 @@ +from .file_client import FileClient +from .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img +from .logger import MessageLogger, get_env_info, get_root_logger, init_tb_logger, init_wandb_logger +from .misc import check_resume, get_time_str, make_exp_dirs, mkdir_and_rename, scandir, set_random_seed, sizeof_fmt + +__all__ = [ + # file_client.py + "FileClient", + # img_util.py + "img2tensor", + "tensor2img", + "imfrombytes", + "imwrite", + "crop_border", + # logger.py + "MessageLogger", + "init_tb_logger", + "init_wandb_logger", + "get_root_logger", + "get_env_info", + # misc.py + "set_random_seed", + "get_time_str", + "mkdir_and_rename", + "make_exp_dirs", + "scandir", + "check_resume", + "sizeof_fmt", +] diff --git a/hordelib/nodes/facerestore_cf/basicsr/utils/dist_util.py b/hordelib/nodes/facerestore_cf/basicsr/utils/dist_util.py new file mode 100644 index 00000000..a252a461 --- /dev/null +++ b/hordelib/nodes/facerestore_cf/basicsr/utils/dist_util.py @@ -0,0 +1,83 @@ +# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py +import functools +import os +import 
subprocess + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + + +def init_dist(launcher, backend="nccl", **kwargs): + if mp.get_start_method(allow_none=True) is None: + mp.set_start_method("spawn") + if launcher == "pytorch": + _init_dist_pytorch(backend, **kwargs) + elif launcher == "slurm": + _init_dist_slurm(backend, **kwargs) + else: + raise ValueError(f"Invalid launcher type: {launcher}") + + +def _init_dist_pytorch(backend, **kwargs): + rank = int(os.environ["RANK"]) + num_gpus = torch.cuda.device_count() + torch.cuda.set_device(rank % num_gpus) + dist.init_process_group(backend=backend, **kwargs) + + +def _init_dist_slurm(backend, port=None): + """Initialize slurm distributed training environment. + + If argument ``port`` is not specified, then the master port will be system + environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system + environment variable, then a default port ``29500`` will be used. + + Args: + backend (str): Backend of torch.distributed. + port (int, optional): Master port. Defaults to None. + """ + proc_id = int(os.environ["SLURM_PROCID"]) + ntasks = int(os.environ["SLURM_NTASKS"]) + node_list = os.environ["SLURM_NODELIST"] + num_gpus = torch.cuda.device_count() + torch.cuda.set_device(proc_id % num_gpus) + addr = subprocess.getoutput(f"scontrol show hostname {node_list} | head -n1") + # specify master port + if port is not None: + os.environ["MASTER_PORT"] = str(port) + elif "MASTER_PORT" in os.environ: + pass # use MASTER_PORT in the environment variable + else: + # 29500 is torch.distributed default port + os.environ["MASTER_PORT"] = "29500" + os.environ["MASTER_ADDR"] = addr + os.environ["WORLD_SIZE"] = str(ntasks) + os.environ["LOCAL_RANK"] = str(proc_id % num_gpus) + os.environ["RANK"] = str(proc_id) + dist.init_process_group(backend=backend) + + +def get_dist_info(): + if dist.is_available(): + initialized = dist.is_initialized() + else: + initialized = False + if initialized: + rank = dist.get_rank() + world_size = dist.get_world_size() + else: + rank = 0 + world_size = 1 + return rank, world_size + + +def master_only(func): + + @functools.wraps(func) + def wrapper(*args, **kwargs): + rank, _ = get_dist_info() + if rank == 0: + return func(*args, **kwargs) + + return wrapper diff --git a/hordelib/nodes/facerestore_cf/basicsr/utils/download_util.py b/hordelib/nodes/facerestore_cf/basicsr/utils/download_util.py new file mode 100644 index 00000000..620c9f73 --- /dev/null +++ b/hordelib/nodes/facerestore_cf/basicsr/utils/download_util.py @@ -0,0 +1,83 @@ +import math +import os +from urllib.parse import urlparse + +import requests +from torch.hub import download_url_to_file, get_dir +from tqdm import tqdm + +from .misc import sizeof_fmt +from hordelib.shared_model_manager import SharedModelManager + + +def download_file_from_google_drive(file_id, save_path): + """Download files from google drive. + Ref: + https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive # noqa E501 + Args: + file_id (str): File id. + save_path (str): Save path. 
+ """ + + session = requests.Session() + URL = "https://docs.google.com/uc?export=download" + params = {"id": file_id} + + response = session.get(URL, params=params, stream=True) + token = get_confirm_token(response) + if token: + params["confirm"] = token + response = session.get(URL, params=params, stream=True) + + # get file size + response_file_size = session.get(URL, params=params, stream=True, headers={"Range": "bytes=0-2"}) + print(response_file_size) + if "Content-Range" in response_file_size.headers: + file_size = int(response_file_size.headers["Content-Range"].split("/")[1]) + else: + file_size = None + + save_response_content(response, save_path, file_size) + + +def get_confirm_token(response): + for key, value in response.cookies.items(): + if key.startswith("download_warning"): + return value + return None + + +def save_response_content(response, destination, file_size=None, chunk_size=32768): + if file_size is not None: + pbar = tqdm(total=math.ceil(file_size / chunk_size), unit="chunk") + + readable_file_size = sizeof_fmt(file_size) + else: + pbar = None + + with open(destination, "wb") as f: + downloaded_size = 0 + for chunk in response.iter_content(chunk_size): + downloaded_size += chunk_size + if pbar is not None: + pbar.update(1) + pbar.set_description(f"Download {sizeof_fmt(downloaded_size)} / {readable_file_size}") + if chunk: # filter out keep-alive new chunks + f.write(chunk) + if pbar is not None: + pbar.close() + + +def load_file_from_url(url, model_dir=None, progress=True, file_name=None): + """Load file form http url, will download models if necessary. + Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py + Args: + url (str): URL to be downloaded. + model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir. + Default: None. + progress (bool): Whether to show the download progress. Default: True. + file_name (str): The downloaded file name. If None, use the file name in the url. Default: None. + Returns: + str: The path to the downloaded file. + """ + return str(SharedModelManager.manager.gfpgan.model_folder_path / file_name) diff --git a/hordelib/nodes/facerestore_cf/basicsr/utils/file_client.py b/hordelib/nodes/facerestore_cf/basicsr/utils/file_client.py new file mode 100644 index 00000000..d9c1d273 --- /dev/null +++ b/hordelib/nodes/facerestore_cf/basicsr/utils/file_client.py @@ -0,0 +1,172 @@ +# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py +from abc import ABCMeta, abstractmethod + + +class BaseStorageBackend(metaclass=ABCMeta): + """Abstract class of storage backends. + + All backends need to implement two apis: ``get()`` and ``get_text()``. + ``get()`` reads the file as a byte stream and ``get_text()`` reads the file + as texts. + """ + + @abstractmethod + def get(self, filepath): + pass + + @abstractmethod + def get_text(self, filepath): + pass + + +class MemcachedBackend(BaseStorageBackend): + """Memcached storage backend. + + Attributes: + server_list_cfg (str): Config file for memcached server list. + client_cfg (str): Config file for memcached client. + sys_path (str | None): Additional path to be appended to `sys.path`. + Default: None. 
+ """ + + def __init__(self, server_list_cfg, client_cfg, sys_path=None): + if sys_path is not None: + import sys + + sys.path.append(sys_path) + try: + import mc + except ImportError: + raise ImportError("Please install memcached to enable MemcachedBackend.") + + self.server_list_cfg = server_list_cfg + self.client_cfg = client_cfg + self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg, self.client_cfg) + # mc.pyvector servers as a point which points to a memory cache + self._mc_buffer = mc.pyvector() + + def get(self, filepath): + filepath = str(filepath) + import mc + + self._client.Get(filepath, self._mc_buffer) + value_buf = mc.ConvertBuffer(self._mc_buffer) + return value_buf + + def get_text(self, filepath): + raise NotImplementedError + + +class HardDiskBackend(BaseStorageBackend): + """Raw hard disks storage backend.""" + + def get(self, filepath): + filepath = str(filepath) + with open(filepath, "rb") as f: + value_buf = f.read() + return value_buf + + def get_text(self, filepath): + filepath = str(filepath) + with open(filepath) as f: + value_buf = f.read() + return value_buf + + +class LmdbBackend(BaseStorageBackend): + """Lmdb storage backend. + + Args: + db_paths (str | list[str]): Lmdb database paths. + client_keys (str | list[str]): Lmdb client keys. Default: 'default'. + readonly (bool, optional): Lmdb environment parameter. If True, + disallow any write operations. Default: True. + lock (bool, optional): Lmdb environment parameter. If False, when + concurrent access occurs, do not lock the database. Default: False. + readahead (bool, optional): Lmdb environment parameter. If False, + disable the OS filesystem readahead mechanism, which may improve + random read performance when a database is larger than RAM. + Default: False. + + Attributes: + db_paths (list): Lmdb database path. + _client (list): A list of several lmdb envs. + """ + + def __init__(self, db_paths, client_keys="default", readonly=True, lock=False, readahead=False, **kwargs): + try: + import lmdb + except ImportError: + raise ImportError("Please install lmdb to enable LmdbBackend.") + + if isinstance(client_keys, str): + client_keys = [client_keys] + + if isinstance(db_paths, list): + self.db_paths = [str(v) for v in db_paths] + elif isinstance(db_paths, str): + self.db_paths = [str(db_paths)] + assert len(client_keys) == len(self.db_paths), ( + "client_keys and db_paths should have the same length, " + f"but received {len(client_keys)} and {len(self.db_paths)}." + ) + + self._client = {} + for client, path in zip(client_keys, self.db_paths, strict=False): + self._client[client] = lmdb.open(path, readonly=readonly, lock=lock, readahead=readahead, **kwargs) + + def get(self, filepath, client_key): + """Get values according to the filepath from one lmdb named client_key. + + Args: + filepath (str | obj:`Path`): Here, filepath is the lmdb key. + client_key (str): Used for distinguishing differnet lmdb envs. + """ + filepath = str(filepath) + assert client_key in self._client, f"client_key {client_key} is not " "in lmdb clients." + client = self._client[client_key] + with client.begin(write=False) as txn: + value_buf = txn.get(filepath.encode("ascii")) + return value_buf + + def get_text(self, filepath): + raise NotImplementedError + + +class FileClient: + """A general file client to access files in different backend. + + The client loads a file or text in a specified backend from its path + and return it as a binary file. 
it can also register other backend + accessor with a given name and backend class. + + Attributes: + backend (str): The storage backend type. Options are "disk", + "memcached" and "lmdb". + client (:obj:`BaseStorageBackend`): The backend object. + """ + + _backends = { + "disk": HardDiskBackend, + "memcached": MemcachedBackend, + "lmdb": LmdbBackend, + } + + def __init__(self, backend="disk", **kwargs): + if backend not in self._backends: + raise ValueError( + f"Backend {backend} is not supported. Currently supported ones" f" are {list(self._backends.keys())}", + ) + self.backend = backend + self.client = self._backends[backend](**kwargs) + + def get(self, filepath, client_key="default"): + # client_key is used only for lmdb, where different fileclients have + # different lmdb environments. + if self.backend == "lmdb": + return self.client.get(filepath, client_key) + else: + return self.client.get(filepath) + + def get_text(self, filepath): + return self.client.get_text(filepath) diff --git a/hordelib/nodes/facerestore_cf/basicsr/utils/img_util.py b/hordelib/nodes/facerestore_cf/basicsr/utils/img_util.py new file mode 100644 index 00000000..da64b7aa --- /dev/null +++ b/hordelib/nodes/facerestore_cf/basicsr/utils/img_util.py @@ -0,0 +1,171 @@ +import math +import os + +import cv2 +import numpy as np +import torch +from torchvision.utils import make_grid + + +def img2tensor(imgs, bgr2rgb=True, float32=True): + """Numpy array to tensor. + + Args: + imgs (list[ndarray] | ndarray): Input images. + bgr2rgb (bool): Whether to change bgr to rgb. + float32 (bool): Whether to change to float32. + + Returns: + list[tensor] | tensor: Tensor images. If returned results only have + one element, just return tensor. + """ + + def _totensor(img, bgr2rgb, float32): + if img.shape[2] == 3 and bgr2rgb: + if img.dtype == "float64": + img = img.astype("float32") + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = torch.from_numpy(img.transpose(2, 0, 1)) + if float32: + img = img.float() + return img + + if isinstance(imgs, list): + return [_totensor(img, bgr2rgb, float32) for img in imgs] + else: + return _totensor(imgs, bgr2rgb, float32) + + +def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)): + """Convert torch Tensors into image numpy arrays. + + After clamping to [min, max], values will be normalized to [0, 1]. + + Args: + tensor (Tensor or list[Tensor]): Accept shapes: + 1) 4D mini-batch Tensor of shape (B x 3/1 x H x W); + 2) 3D Tensor of shape (3/1 x H x W); + 3) 2D Tensor of shape (H x W). + Tensor channel should be in RGB order. + rgb2bgr (bool): Whether to change rgb to bgr. + out_type (numpy type): output types. If ``np.uint8``, transform outputs + to uint8 type with range [0, 255]; otherwise, float type with + range [0, 1]. Default: ``np.uint8``. + min_max (tuple[int]): min and max values for clamp. + + Returns: + (Tensor or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of + shape (H x W). The channel order is BGR. 
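A minimal round-trip sketch of these two helpers (the shapes and random values below are arbitrary; both functions are defined in this module):

    >>> import numpy as np
    >>> bgr = np.random.rand(8, 8, 3).astype(np.float32)      # HWC, BGR, range [0, 1]
    >>> t = img2tensor(bgr, bgr2rgb=True, float32=True)        # CHW float tensor, RGB
    >>> out = tensor2img(t, rgb2bgr=True, out_type=np.uint8)   # back to HWC, BGR, uint8
    >>> out.shape, out.dtype
    ((8, 8, 3), dtype('uint8'))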
+ """ + if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))): + raise TypeError(f"tensor or list of tensors expected, got {type(tensor)}") + + if torch.is_tensor(tensor): + tensor = [tensor] + result = [] + for _tensor in tensor: + _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max) + _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0]) + + n_dim = _tensor.dim() + if n_dim == 4: + img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy() + img_np = img_np.transpose(1, 2, 0) + if rgb2bgr: + img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) + elif n_dim == 3: + img_np = _tensor.numpy() + img_np = img_np.transpose(1, 2, 0) + if img_np.shape[2] == 1: # gray image + img_np = np.squeeze(img_np, axis=2) + else: + if rgb2bgr: + img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) + elif n_dim == 2: + img_np = _tensor.numpy() + else: + raise TypeError("Only support 4D, 3D or 2D tensor. " f"But received with dimension: {n_dim}") + if out_type == np.uint8: + # Unlike MATLAB, numpy.unit8() WILL NOT round by default. + img_np = (img_np * 255.0).round() + img_np = img_np.astype(out_type) + result.append(img_np) + if len(result) == 1: + result = result[0] + return result + + +def tensor2img_fast(tensor, rgb2bgr=True, min_max=(0, 1)): + """This implementation is slightly faster than tensor2img. + It now only supports torch tensor with shape (1, c, h, w). + + Args: + tensor (Tensor): Now only support torch tensor with (1, c, h, w). + rgb2bgr (bool): Whether to change rgb to bgr. Default: True. + min_max (tuple[int]): min and max values for clamp. + """ + output = tensor.squeeze(0).detach().clamp_(*min_max).permute(1, 2, 0) + output = (output - min_max[0]) / (min_max[1] - min_max[0]) * 255 + output = output.type(torch.uint8).cpu().numpy() + if rgb2bgr: + output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR) + return output + + +def imfrombytes(content, flag="color", float32=False): + """Read an image from bytes. + + Args: + content (bytes): Image bytes got from files or other streams. + flag (str): Flags specifying the color type of a loaded image, + candidates are `color`, `grayscale` and `unchanged`. + float32 (bool): Whether to change to float32., If True, will also norm + to [0, 1]. Default: False. + + Returns: + ndarray: Loaded image array. + """ + img_np = np.frombuffer(content, np.uint8) + imread_flags = {"color": cv2.IMREAD_COLOR, "grayscale": cv2.IMREAD_GRAYSCALE, "unchanged": cv2.IMREAD_UNCHANGED} + img = cv2.imdecode(img_np, imread_flags[flag]) + if float32: + img = img.astype(np.float32) / 255.0 + return img + + +def imwrite(img, file_path, params=None, auto_mkdir=True): + """Write image to file. + + Args: + img (ndarray): Image array to be written. + file_path (str): Image file path. + params (None or list): Same as opencv's :func:`imwrite` interface. + auto_mkdir (bool): If the parent folder of `file_path` does not exist, + whether to create it automatically. + + Returns: + bool: Successful or not. + """ + if auto_mkdir: + dir_name = os.path.abspath(os.path.dirname(file_path)) + os.makedirs(dir_name, exist_ok=True) + return cv2.imwrite(file_path, img, params) + + +def crop_border(imgs, crop_border): + """Crop borders of images. + + Args: + imgs (list[ndarray] | ndarray): Images with shape (h, w, c). + crop_border (int): Crop border for each end of height and weight. + + Returns: + list[ndarray]: Cropped images. 
+ """ + if crop_border == 0: + return imgs + else: + if isinstance(imgs, list): + return [v[crop_border:-crop_border, crop_border:-crop_border, ...] for v in imgs] + else: + return imgs[crop_border:-crop_border, crop_border:-crop_border, ...] diff --git a/hordelib/nodes/facerestore_cf/basicsr/utils/lmdb_util.py b/hordelib/nodes/facerestore_cf/basicsr/utils/lmdb_util.py new file mode 100644 index 00000000..21ecbfd5 --- /dev/null +++ b/hordelib/nodes/facerestore_cf/basicsr/utils/lmdb_util.py @@ -0,0 +1,200 @@ +import sys +from multiprocessing import Pool +from os import path as osp + +import cv2 +import lmdb +from tqdm import tqdm + + +def make_lmdb_from_imgs( + data_path, + lmdb_path, + img_path_list, + keys, + batch=5000, + compress_level=1, + multiprocessing_read=False, + n_thread=40, + map_size=None, +): + """Make lmdb from images. + + Contents of lmdb. The file structure is: + example.lmdb + ├── data.mdb + ├── lock.mdb + ├── meta_info.txt + + The data.mdb and lock.mdb are standard lmdb files and you can refer to + https://lmdb.readthedocs.io/en/release/ for more details. + + The meta_info.txt is a specified txt file to record the meta information + of our datasets. It will be automatically created when preparing + datasets by our provided dataset tools. + Each line in the txt file records 1)image name (with extension), + 2)image shape, and 3)compression level, separated by a white space. + + For example, the meta information could be: + `000_00000000.png (720,1280,3) 1`, which means: + 1) image name (with extension): 000_00000000.png; + 2) image shape: (720,1280,3); + 3) compression level: 1 + + We use the image name without extension as the lmdb key. + + If `multiprocessing_read` is True, it will read all the images to memory + using multiprocessing. Thus, your server needs to have enough memory. + + Args: + data_path (str): Data path for reading images. + lmdb_path (str): Lmdb save path. + img_path_list (str): Image path list. + keys (str): Used for lmdb keys. + batch (int): After processing batch images, lmdb commits. + Default: 5000. + compress_level (int): Compress level when encoding images. Default: 1. + multiprocessing_read (bool): Whether use multiprocessing to read all + the images to memory. Default: False. + n_thread (int): For multiprocessing. + map_size (int | None): Map size for lmdb env. If None, use the + estimated size from images. Default: None + """ + + assert len(img_path_list) == len(keys), ( + "img_path_list and keys should have the same length, " f"but got {len(img_path_list)} and {len(keys)}" + ) + print(f"Create lmdb for {data_path}, save to {lmdb_path}...") + print(f"Totoal images: {len(img_path_list)}") + if not lmdb_path.endswith(".lmdb"): + raise ValueError("lmdb_path must end with '.lmdb'.") + if osp.exists(lmdb_path): + print(f"Folder {lmdb_path} already exists. 
Exit.") + sys.exit(1) + + if multiprocessing_read: + # read all the images to memory (multiprocessing) + dataset = {} # use dict to keep the order for multiprocessing + shapes = {} + print(f"Read images with multiprocessing, #thread: {n_thread} ...") + pbar = tqdm(total=len(img_path_list), unit="image") + + def callback(arg): + """get the image data and update pbar.""" + key, dataset[key], shapes[key] = arg + pbar.update(1) + pbar.set_description(f"Read {key}") + + pool = Pool(n_thread) + for path, key in zip(img_path_list, keys, strict=False): + pool.apply_async(read_img_worker, args=(osp.join(data_path, path), key, compress_level), callback=callback) + pool.close() + pool.join() + pbar.close() + print(f"Finish reading {len(img_path_list)} images.") + + # create lmdb environment + if map_size is None: + # obtain data size for one image + img = cv2.imread(osp.join(data_path, img_path_list[0]), cv2.IMREAD_UNCHANGED) + _, img_byte = cv2.imencode(".png", img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level]) + data_size_per_img = img_byte.nbytes + print("Data size per image is: ", data_size_per_img) + data_size = data_size_per_img * len(img_path_list) + map_size = data_size * 10 + + env = lmdb.open(lmdb_path, map_size=map_size) + + # write data to lmdb + pbar = tqdm(total=len(img_path_list), unit="chunk") + txn = env.begin(write=True) + txt_file = open(osp.join(lmdb_path, "meta_info.txt"), "w") + for idx, (path, key) in enumerate(zip(img_path_list, keys, strict=False)): + pbar.update(1) + pbar.set_description(f"Write {key}") + key_byte = key.encode("ascii") + if multiprocessing_read: + img_byte = dataset[key] + h, w, c = shapes[key] + else: + _, img_byte, img_shape = read_img_worker(osp.join(data_path, path), key, compress_level) + h, w, c = img_shape + + txn.put(key_byte, img_byte) + # write meta information + txt_file.write(f"{key}.png ({h},{w},{c}) {compress_level}\n") + if idx % batch == 0: + txn.commit() + txn = env.begin(write=True) + pbar.close() + txn.commit() + env.close() + txt_file.close() + print("\nFinish writing lmdb.") + + +def read_img_worker(path, key, compress_level): + """Read image worker. + + Args: + path (str): Image path. + key (str): Image key. + compress_level (int): Compress level when encoding images. + + Returns: + str: Image key. + byte: Image byte. + tuple[int]: Image shape. + """ + + img = cv2.imread(path, cv2.IMREAD_UNCHANGED) + if img.ndim == 2: + h, w = img.shape + c = 1 + else: + h, w, c = img.shape + _, img_byte = cv2.imencode(".png", img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level]) + return (key, img_byte, (h, w, c)) + + +class LmdbMaker: + """LMDB Maker. + + Args: + lmdb_path (str): Lmdb save path. + map_size (int): Map size for lmdb env. Default: 1024 ** 4, 1TB. + batch (int): After processing batch images, lmdb commits. + Default: 5000. + compress_level (int): Compress level when encoding images. Default: 1. + """ + + def __init__(self, lmdb_path, map_size=1024**4, batch=5000, compress_level=1): + if not lmdb_path.endswith(".lmdb"): + raise ValueError("lmdb_path must end with '.lmdb'.") + if osp.exists(lmdb_path): + print(f"Folder {lmdb_path} already exists. 
Exit.") + sys.exit(1) + + self.lmdb_path = lmdb_path + self.batch = batch + self.compress_level = compress_level + self.env = lmdb.open(lmdb_path, map_size=map_size) + self.txn = self.env.begin(write=True) + self.txt_file = open(osp.join(lmdb_path, "meta_info.txt"), "w") + self.counter = 0 + + def put(self, img_byte, key, img_shape): + self.counter += 1 + key_byte = key.encode("ascii") + self.txn.put(key_byte, img_byte) + # write meta information + h, w, c = img_shape + self.txt_file.write(f"{key}.png ({h},{w},{c}) {self.compress_level}\n") + if self.counter % self.batch == 0: + self.txn.commit() + self.txn = self.env.begin(write=True) + + def close(self): + self.txn.commit() + self.env.close() + self.txt_file.close() diff --git a/hordelib/nodes/facerestore_cf/basicsr/utils/logger.py b/hordelib/nodes/facerestore_cf/basicsr/utils/logger.py new file mode 100644 index 00000000..1201c879 --- /dev/null +++ b/hordelib/nodes/facerestore_cf/basicsr/utils/logger.py @@ -0,0 +1,174 @@ +import datetime +import logging +import time + +from .dist_util import get_dist_info, master_only + +initialized_logger = {} + + +class MessageLogger: + """Message logger for printing. + Args: + opt (dict): Config. It contains the following keys: + name (str): Exp name. + logger (dict): Contains 'print_freq' (str) for logger interval. + train (dict): Contains 'total_iter' (int) for total iters. + use_tb_logger (bool): Use tensorboard logger. + start_iter (int): Start iter. Default: 1. + tb_logger (obj:`tb_logger`): Tensorboard logger. Default: None. + """ + + def __init__(self, opt, start_iter=1, tb_logger=None): + self.exp_name = opt["name"] + self.interval = opt["logger"]["print_freq"] + self.start_iter = start_iter + self.max_iters = opt["train"]["total_iter"] + self.use_tb_logger = opt["logger"]["use_tb_logger"] + self.tb_logger = tb_logger + self.start_time = time.time() + self.logger = get_root_logger() + + @master_only + def __call__(self, log_vars): + """Format logging message. + Args: + log_vars (dict): It contains the following keys: + epoch (int): Epoch number. + iter (int): Current iter. + lrs (list): List for learning rates. + time (float): Iter time. + data_time (float): Data time for each iter. 
+ """ + # epoch, iter, learning rates + epoch = log_vars.pop("epoch") + current_iter = log_vars.pop("iter") + lrs = log_vars.pop("lrs") + + message = f"[{self.exp_name[:5]}..][epoch:{epoch:3d}, " f"iter:{current_iter:8,d}, lr:(" + for v in lrs: + message += f"{v:.3e}," + message += ")] " + + # time and estimated time + if "time" in log_vars.keys(): + iter_time = log_vars.pop("time") + data_time = log_vars.pop("data_time") + + total_time = time.time() - self.start_time + time_sec_avg = total_time / (current_iter - self.start_iter + 1) + eta_sec = time_sec_avg * (self.max_iters - current_iter - 1) + eta_str = str(datetime.timedelta(seconds=int(eta_sec))) + message += f"[eta: {eta_str}, " + message += f"time (data): {iter_time:.3f} ({data_time:.3f})] " + + # other items, especially losses + for k, v in log_vars.items(): + message += f"{k}: {v:.4e} " + # tensorboard logger + if self.use_tb_logger: + if k.startswith("l_"): + self.tb_logger.add_scalar(f"losses/{k}", v, current_iter) + else: + self.tb_logger.add_scalar(k, v, current_iter) + self.logger.info(message) + + +@master_only +def init_tb_logger(log_dir): + from torch.utils.tensorboard import SummaryWriter + + tb_logger = SummaryWriter(log_dir=log_dir) + return tb_logger + + +@master_only +def init_wandb_logger(opt): + """We now only use wandb to sync tensorboard log.""" + import wandb + + logger = logging.getLogger("basicsr") + + project = opt["logger"]["wandb"]["project"] + resume_id = opt["logger"]["wandb"].get("resume_id") + if resume_id: + wandb_id = resume_id + resume = "allow" + logger.warning(f"Resume wandb logger with id={wandb_id}.") + else: + wandb_id = wandb.util.generate_id() + resume = "never" + + wandb.init(id=wandb_id, resume=resume, name=opt["name"], config=opt, project=project, sync_tensorboard=True) + + logger.info(f"Use wandb logger with id={wandb_id}; project={project}.") + + +def get_root_logger(logger_name="basicsr", log_level=logging.INFO, log_file=None): + """Get the root logger. + The logger will be initialized if it has not been initialized. By default a + StreamHandler will be added. If `log_file` is specified, a FileHandler will + also be added. + Args: + logger_name (str): root logger name. Default: 'basicsr'. + log_file (str | None): The log filename. If specified, a FileHandler + will be added to the root logger. + log_level (int): The root logger level. Note that only the process of + rank 0 is affected, while other processes will set the level to + "Error" and be silent most of the time. + Returns: + logging.Logger: The root logger. + """ + logger = logging.getLogger(logger_name) + # if the logger has been initialized, just return it + if logger_name in initialized_logger: + return logger + + format_str = "%(asctime)s %(levelname)s: %(message)s" + stream_handler = logging.StreamHandler() + stream_handler.setFormatter(logging.Formatter(format_str)) + logger.addHandler(stream_handler) + logger.propagate = False + rank, _ = get_dist_info() + if rank != 0: + logger.setLevel("ERROR") + elif log_file is not None: + logger.setLevel(log_level) + # add file handler + # file_handler = logging.FileHandler(log_file, 'w') + file_handler = logging.FileHandler(log_file, "a") # Shangchen: keep the previous log + file_handler.setFormatter(logging.Formatter(format_str)) + file_handler.setLevel(log_level) + logger.addHandler(file_handler) + initialized_logger[logger_name] = True + return logger + + +def get_env_info(): + """Get environment information. + Currently, only log the software version. 
+ """ + import torch + import torchvision + + from hordelib.nodes.facerestore_cf.basicsr.version import __version__ + + msg = r""" + ____ _ _____ ____ + / __ ) ____ _ _____ (_)_____/ ___/ / __ \ + / __ |/ __ `// ___// // ___/\__ \ / /_/ / + / /_/ // /_/ /(__ )/ // /__ ___/ // _, _/ + /_____/ \__,_//____//_/ \___//____//_/ |_| + ______ __ __ __ __ + / ____/____ ____ ____/ / / / __ __ _____ / /__ / / + / / __ / __ \ / __ \ / __ / / / / / / // ___// //_/ / / + / /_/ // /_/ // /_/ // /_/ / / /___/ /_/ // /__ / /< /_/ + \____/ \____/ \____/ \____/ /_____/\____/ \___//_/|_| (_) + """ + msg += ( + "\nVersion Information: " + f"\n\tBasicSR: {__version__}" + f"\n\tPyTorch: {torch.__version__}" + f"\n\tTorchVision: {torchvision.__version__}" + ) + return msg diff --git a/hordelib/nodes/facerestore_cf/basicsr/utils/matlab_functions.py b/hordelib/nodes/facerestore_cf/basicsr/utils/matlab_functions.py new file mode 100644 index 00000000..5a637988 --- /dev/null +++ b/hordelib/nodes/facerestore_cf/basicsr/utils/matlab_functions.py @@ -0,0 +1,372 @@ +import math + +import numpy as np +import torch + + +def cubic(x): + """cubic function used for calculate_weights_indices.""" + absx = torch.abs(x) + absx2 = absx**2 + absx3 = absx**3 + return (1.5 * absx3 - 2.5 * absx2 + 1) * ((absx <= 1).type_as(absx)) + ( + -0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2 + ) * (((absx > 1) * (absx <= 2)).type_as(absx)) + + +def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing): + """Calculate weights and indices, used for imresize function. + + Args: + in_length (int): Input length. + out_length (int): Output length. + scale (float): Scale factor. + kernel_width (int): Kernel width. + antialisaing (bool): Whether to apply anti-aliasing when downsampling. + """ + + if (scale < 1) and antialiasing: + # Use a modified kernel (larger kernel width) to simultaneously + # interpolate and antialias + kernel_width = kernel_width / scale + + # Output-space coordinates + x = torch.linspace(1, out_length, out_length) + + # Input-space coordinates. Calculate the inverse mapping such that 0.5 + # in output space maps to 0.5 in input space, and 0.5 + scale in output + # space maps to 1.5 in input space. + u = x / scale + 0.5 * (1 - 1 / scale) + + # What is the left-most pixel that can be involved in the computation? + left = torch.floor(u - kernel_width / 2) + + # What is the maximum number of pixels that can be involved in the + # computation? Note: it's OK to use an extra pixel here; if the + # corresponding weights are all zero, it will be eliminated at the end + # of this function. + p = math.ceil(kernel_width) + 2 + + # The indices of the input pixels involved in computing the k-th output + # pixel are in row k of the indices matrix. + indices = left.view(out_length, 1).expand(out_length, p) + torch.linspace(0, p - 1, p).view(1, p).expand( + out_length, + p, + ) + + # The weights used to compute the k-th output pixel are in row k of the + # weights matrix. + distance_to_center = u.view(out_length, 1).expand(out_length, p) - indices + + # apply cubic kernel + if (scale < 1) and antialiasing: + weights = scale * cubic(distance_to_center * scale) + else: + weights = cubic(distance_to_center) + + # Normalize the weights matrix so that each row sums to 1. + weights_sum = torch.sum(weights, 1).view(out_length, 1) + weights = weights / weights_sum.expand(out_length, p) + + # If a column in weights is all zero, get rid of it. only consider the + # first and last column. 
+ weights_zero_tmp = torch.sum((weights == 0), 0) + if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6): + indices = indices.narrow(1, 1, p - 2) + weights = weights.narrow(1, 1, p - 2) + if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6): + indices = indices.narrow(1, 0, p - 2) + weights = weights.narrow(1, 0, p - 2) + weights = weights.contiguous() + indices = indices.contiguous() + sym_len_s = -indices.min() + 1 + sym_len_e = indices.max() - in_length + indices = indices + sym_len_s - 1 + return weights, indices, int(sym_len_s), int(sym_len_e) + + +@torch.no_grad() +def imresize(img, scale, antialiasing=True): + """imresize function same as MATLAB. + + It now only supports bicubic. + The same scale applies for both height and width. + + Args: + img (Tensor | Numpy array): + Tensor: Input image with shape (c, h, w), [0, 1] range. + Numpy: Input image with shape (h, w, c), [0, 1] range. + scale (float): Scale factor. The same scale applies for both height + and width. + antialisaing (bool): Whether to apply anti-aliasing when downsampling. + Default: True. + + Returns: + Tensor: Output image with shape (c, h, w), [0, 1] range, w/o round. + """ + if type(img).__module__ == np.__name__: # numpy type + numpy_type = True + img = torch.from_numpy(img.transpose(2, 0, 1)).float() + else: + numpy_type = False + + in_c, in_h, in_w = img.size() + out_h, out_w = math.ceil(in_h * scale), math.ceil(in_w * scale) + kernel_width = 4 + kernel = "cubic" + + # get weights and indices + weights_h, indices_h, sym_len_hs, sym_len_he = calculate_weights_indices( + in_h, + out_h, + scale, + kernel, + kernel_width, + antialiasing, + ) + weights_w, indices_w, sym_len_ws, sym_len_we = calculate_weights_indices( + in_w, + out_w, + scale, + kernel, + kernel_width, + antialiasing, + ) + # process H dimension + # symmetric copying + img_aug = torch.FloatTensor(in_c, in_h + sym_len_hs + sym_len_he, in_w) + img_aug.narrow(1, sym_len_hs, in_h).copy_(img) + + sym_patch = img[:, :sym_len_hs, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + img_aug.narrow(1, 0, sym_len_hs).copy_(sym_patch_inv) + + sym_patch = img[:, -sym_len_he:, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + img_aug.narrow(1, sym_len_hs + in_h, sym_len_he).copy_(sym_patch_inv) + + out_1 = torch.FloatTensor(in_c, out_h, in_w) + kernel_width = weights_h.size(1) + for i in range(out_h): + idx = int(indices_h[i][0]) + for j in range(in_c): + out_1[j, i, :] = img_aug[j, idx : idx + kernel_width, :].transpose(0, 1).mv(weights_h[i]) + + # process W dimension + # symmetric copying + out_1_aug = torch.FloatTensor(in_c, out_h, in_w + sym_len_ws + sym_len_we) + out_1_aug.narrow(2, sym_len_ws, in_w).copy_(out_1) + + sym_patch = out_1[:, :, :sym_len_ws] + inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(2, inv_idx) + out_1_aug.narrow(2, 0, sym_len_ws).copy_(sym_patch_inv) + + sym_patch = out_1[:, :, -sym_len_we:] + inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(2, inv_idx) + out_1_aug.narrow(2, sym_len_ws + in_w, sym_len_we).copy_(sym_patch_inv) + + out_2 = torch.FloatTensor(in_c, out_h, out_w) + kernel_width = weights_w.size(1) + for i in range(out_w): + idx = int(indices_w[i][0]) + for j in range(in_c): + out_2[j, :, i] = out_1_aug[j, :, idx : idx + kernel_width].mv(weights_w[i]) + + if numpy_type: + 
out_2 = out_2.numpy().transpose(1, 2, 0) + return out_2 + + +def rgb2ycbcr(img, y_only=False): + """Convert a RGB image to YCbCr image. + + This function produces the same results as Matlab's `rgb2ycbcr` function. + It implements the ITU-R BT.601 conversion for standard-definition + television. See more details in + https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. + + It differs from a similar function in cv2.cvtColor: `RGB <-> YCrCb`. + In OpenCV, it implements a JPEG conversion. See more details in + https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + y_only (bool): Whether to only return Y channel. Default: False. + + Returns: + ndarray: The converted YCbCr image. The output image has the same type + and range as input image. + """ + img_type = img.dtype + img = _convert_input_type_range(img) + if y_only: + out_img = np.dot(img, [65.481, 128.553, 24.966]) + 16.0 + else: + out_img = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], [24.966, 112.0, -18.214]]) + [ + 16, + 128, + 128, + ] + out_img = _convert_output_type_range(out_img, img_type) + return out_img + + +def bgr2ycbcr(img, y_only=False): + """Convert a BGR image to YCbCr image. + + The bgr version of rgb2ycbcr. + It implements the ITU-R BT.601 conversion for standard-definition + television. See more details in + https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. + + It differs from a similar function in cv2.cvtColor: `BGR <-> YCrCb`. + In OpenCV, it implements a JPEG conversion. See more details in + https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + y_only (bool): Whether to only return Y channel. Default: False. + + Returns: + ndarray: The converted YCbCr image. The output image has the same type + and range as input image. + """ + img_type = img.dtype + img = _convert_input_type_range(img) + if y_only: + out_img = np.dot(img, [24.966, 128.553, 65.481]) + 16.0 + else: + out_img = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], [65.481, -37.797, 112.0]]) + [ + 16, + 128, + 128, + ] + out_img = _convert_output_type_range(out_img, img_type) + return out_img + + +def ycbcr2rgb(img): + """Convert a YCbCr image to RGB image. + + This function produces the same results as Matlab's ycbcr2rgb function. + It implements the ITU-R BT.601 conversion for standard-definition + television. See more details in + https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. + + It differs from a similar function in cv2.cvtColor: `YCrCb <-> RGB`. + In OpenCV, it implements a JPEG conversion. See more details in + https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + + Returns: + ndarray: The converted RGB image. The output image has the same type + and range as input image. + """ + img_type = img.dtype + img = _convert_input_type_range(img) * 255 + out_img = np.matmul( + img, + [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071], [0.00625893, -0.00318811, 0]], + ) * 255.0 + [-222.921, 135.576, -276.836] + out_img = _convert_output_type_range(out_img, img_type) + return out_img + + +def ycbcr2bgr(img): + """Convert a YCbCr image to BGR image. 
+ + The bgr version of ycbcr2rgb. + It implements the ITU-R BT.601 conversion for standard-definition + television. See more details in + https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. + + It differs from a similar function in cv2.cvtColor: `YCrCb <-> BGR`. + In OpenCV, it implements a JPEG conversion. See more details in + https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + + Returns: + ndarray: The converted BGR image. The output image has the same type + and range as input image. + """ + img_type = img.dtype + img = _convert_input_type_range(img) * 255 + out_img = np.matmul( + img, + [[0.00456621, 0.00456621, 0.00456621], [0.00791071, -0.00153632, 0], [0, -0.00318811, 0.00625893]], + ) * 255.0 + [-276.836, 135.576, -222.921] + out_img = _convert_output_type_range(out_img, img_type) + return out_img + + +def _convert_input_type_range(img): + """Convert the type and range of the input image. + + It converts the input image to np.float32 type and range of [0, 1]. + It is mainly used for pre-processing the input image in colorspace + convertion functions such as rgb2ycbcr and ycbcr2rgb. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + + Returns: + (ndarray): The converted image with type of np.float32 and range of + [0, 1]. + """ + img_type = img.dtype + img = img.astype(np.float32) + if img_type == np.float32: + pass + elif img_type == np.uint8: + img /= 255.0 + else: + raise TypeError("The img type should be np.float32 or np.uint8, " f"but got {img_type}") + return img + + +def _convert_output_type_range(img, dst_type): + """Convert the type and range of the image according to dst_type. + + It converts the image to desired type and range. If `dst_type` is np.uint8, + images will be converted to np.uint8 type with range [0, 255]. If + `dst_type` is np.float32, it converts the image to np.float32 type with + range [0, 1]. + It is mainly used for post-processing images in colorspace convertion + functions such as rgb2ycbcr and ycbcr2rgb. + + Args: + img (ndarray): The image to be converted with np.float32 type and + range [0, 255]. + dst_type (np.uint8 | np.float32): If dst_type is np.uint8, it + converts the image to np.uint8 type with range [0, 255]. If + dst_type is np.float32, it converts the image to np.float32 type + with range [0, 1]. + + Returns: + (ndarray): The converted image with desired type and range. 
+ """ + if dst_type not in (np.uint8, np.float32): + raise TypeError("The dst_type should be np.float32 or np.uint8, " f"but got {dst_type}") + if dst_type == np.uint8: + img = img.round() + else: + img /= 255.0 + return img.astype(dst_type) diff --git a/hordelib/nodes/facerestore_cf/basicsr/utils/misc.py b/hordelib/nodes/facerestore_cf/basicsr/utils/misc.py new file mode 100644 index 00000000..be8fe502 --- /dev/null +++ b/hordelib/nodes/facerestore_cf/basicsr/utils/misc.py @@ -0,0 +1,136 @@ +import os +import random +import time +from os import path as osp + +import numpy as np +import torch + +from .dist_util import master_only +from .logger import get_root_logger + + +def set_random_seed(seed): + """Set random seeds.""" + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def get_time_str(): + return time.strftime("%Y%m%d_%H%M%S", time.localtime()) + + +def mkdir_and_rename(path): + """mkdirs. If path exists, rename it with timestamp and create a new one. + + Args: + path (str): Folder path. + """ + if osp.exists(path): + new_name = path + "_archived_" + get_time_str() + print(f"Path already exists. Rename it to {new_name}", flush=True) + os.rename(path, new_name) + os.makedirs(path, exist_ok=True) + + +@master_only +def make_exp_dirs(opt): + """Make dirs for experiments.""" + path_opt = opt["path"].copy() + if opt["is_train"]: + mkdir_and_rename(path_opt.pop("experiments_root")) + else: + mkdir_and_rename(path_opt.pop("results_root")) + for key, path in path_opt.items(): + if ("strict_load" not in key) and ("pretrain_network" not in key) and ("resume" not in key): + os.makedirs(path, exist_ok=True) + + +def scandir(dir_path, suffix=None, recursive=False, full_path=False): + """Scan a directory to find the interested files. + + Args: + dir_path (str): Path of the directory. + suffix (str | tuple(str), optional): File suffix that we are + interested in. Default: None. + recursive (bool, optional): If set to True, recursively scan the + directory. Default: False. + full_path (bool, optional): If set to True, include the dir_path. + Default: False. + + Returns: + A generator for all the interested files with relative pathes. + """ + + if (suffix is not None) and not isinstance(suffix, (str, tuple)): + raise TypeError('"suffix" must be a string or tuple of strings') + + root = dir_path + + def _scandir(dir_path, suffix, recursive): + for entry in os.scandir(dir_path): + if not entry.name.startswith(".") and entry.is_file(): + if full_path: + return_path = entry.path + else: + return_path = osp.relpath(entry.path, root) + + if suffix is None: + yield return_path + elif return_path.endswith(suffix): + yield return_path + else: + if recursive: + yield from _scandir(entry.path, suffix=suffix, recursive=recursive) + else: + continue + + return _scandir(dir_path, suffix=suffix, recursive=recursive) + + +def check_resume(opt, resume_iter): + """Check resume states and pretrain_network paths. + + Args: + opt (dict): Options. + resume_iter (int): Resume iteration. 
+ """ + logger = get_root_logger() + if opt["path"]["resume_state"]: + # get all the networks + networks = [key for key in opt.keys() if key.startswith("network_")] + flag_pretrain = False + for network in networks: + if opt["path"].get(f"pretrain_{network}") is not None: + flag_pretrain = True + if flag_pretrain: + logger.warning("pretrain_network path will be ignored during resuming.") + # set pretrained model paths + for network in networks: + name = f"pretrain_{network}" + basename = network.replace("network_", "") + if opt["path"].get("ignore_resume_networks") is None or ( + basename not in opt["path"]["ignore_resume_networks"] + ): + opt["path"][name] = osp.join(opt["path"]["models"], f"net_{basename}_{resume_iter}.pth") + logger.info(f"Set {name} to {opt['path'][name]}") + + +def sizeof_fmt(size, suffix="B"): + """Get human readable file size. + + Args: + size (int): File size. + suffix (str): Suffix. Default: 'B'. + + Return: + str: Formated file siz. + """ + for unit in ["", "K", "M", "G", "T", "P", "E", "Z"]: + if abs(size) < 1024.0: + return f"{size:3.1f} {unit}{suffix}" + size /= 1024.0 + return f"{size:3.1f} Y{suffix}" diff --git a/hordelib/nodes/facerestore_cf/basicsr/utils/options.py b/hordelib/nodes/facerestore_cf/basicsr/utils/options.py new file mode 100644 index 00000000..909e4576 --- /dev/null +++ b/hordelib/nodes/facerestore_cf/basicsr/utils/options.py @@ -0,0 +1,109 @@ +from collections import OrderedDict +from os import path as osp + +import yaml + +from hordelib.nodes.facerestore_cf.basicsr.utils.misc import get_time_str + + +def ordered_yaml(): + """Support OrderedDict for yaml. + + Returns: + yaml Loader and Dumper. + """ + try: + from yaml import CDumper as Dumper + from yaml import CLoader as Loader + except ImportError: + from yaml import Dumper, Loader + + _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG + + def dict_representer(dumper, data): + return dumper.represent_dict(data.items()) + + def dict_constructor(loader, node): + return OrderedDict(loader.construct_pairs(node)) + + Dumper.add_representer(OrderedDict, dict_representer) + Loader.add_constructor(_mapping_tag, dict_constructor) + return Loader, Dumper + + +def parse(opt_path, root_path, is_train=True): + """Parse option file. + + Args: + opt_path (str): Option file path. + is_train (str): Indicate whether in training or not. Default: True. + + Returns: + (dict): Options. 
+ """ + with open(opt_path) as f: + Loader, _ = ordered_yaml() + opt = yaml.load(f, Loader=Loader) + + opt["is_train"] = is_train + + # opt['name'] = f"{get_time_str()}_{opt['name']}" + if opt["path"].get("resume_state", None): # Shangchen added + resume_state_path = opt["path"].get("resume_state") + opt["name"] = resume_state_path.split("/")[-3] + else: + opt["name"] = f"{get_time_str()}_{opt['name']}" + + # datasets + for phase, dataset in opt["datasets"].items(): + # for several datasets, e.g., test_1, test_2 + phase = phase.split("_")[0] + dataset["phase"] = phase + if "scale" in opt: + dataset["scale"] = opt["scale"] + if dataset.get("dataroot_gt") is not None: + dataset["dataroot_gt"] = osp.expanduser(dataset["dataroot_gt"]) + if dataset.get("dataroot_lq") is not None: + dataset["dataroot_lq"] = osp.expanduser(dataset["dataroot_lq"]) + + # paths + for key, val in opt["path"].items(): + if (val is not None) and ("resume_state" in key or "pretrain_network" in key): + opt["path"][key] = osp.expanduser(val) + + if is_train: + experiments_root = osp.join(root_path, "experiments", opt["name"]) + opt["path"]["experiments_root"] = experiments_root + opt["path"]["models"] = osp.join(experiments_root, "models") + opt["path"]["training_states"] = osp.join(experiments_root, "training_states") + opt["path"]["log"] = experiments_root + opt["path"]["visualization"] = osp.join(experiments_root, "visualization") + + else: # test + results_root = osp.join(root_path, "results", opt["name"]) + opt["path"]["results_root"] = results_root + opt["path"]["log"] = results_root + opt["path"]["visualization"] = osp.join(results_root, "visualization") + + return opt + + +def dict2str(opt, indent_level=1): + """dict to string for printing options. + + Args: + opt (dict): Option dict. + indent_level (int): Indent level. Default: 1. + + Return: + (str): Option string for printing. + """ + msg = "\n" + for k, v in opt.items(): + if isinstance(v, dict): + msg += " " * (indent_level * 2) + k + ":[" + msg += dict2str(v, indent_level + 1) + msg += " " * (indent_level * 2) + "]\n" + else: + msg += " " * (indent_level * 2) + k + ": " + str(v) + "\n" + return msg diff --git a/hordelib/nodes/facerestore_cf/basicsr/utils/realesrgan_utils.py b/hordelib/nodes/facerestore_cf/basicsr/utils/realesrgan_utils.py new file mode 100644 index 00000000..73c4da85 --- /dev/null +++ b/hordelib/nodes/facerestore_cf/basicsr/utils/realesrgan_utils.py @@ -0,0 +1,319 @@ +import math +import os +import queue +import threading + +import cv2 +import numpy as np +import torch +from torch.nn import functional as F + +from hordelib.nodes.facerestore_cf.basicsr.utils.download_util import load_file_from_url + +# ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + + +class RealESRGANer: + """A helper class for upsampling images with RealESRGAN. + + Args: + scale (int): Upsampling scale factor used in the networks. It is usually 2 or 4. + model_path (str): The path to the pretrained model. It can be urls (will first download it automatically). + model (nn.Module): The defined network. Default: None. + tile (int): As too large images result in the out of GPU memory issue, so this tile option will first crop + input images into tiles, and then process each of them. Finally, they will be merged into one image. + 0 denotes for do not use tile. Default: 0. + tile_pad (int): The pad size for each tile, to remove border artifacts. Default: 10. + pre_pad (int): Pad the input images to avoid border artifacts. Default: 10. 
+ half (float): Whether to use half precision during inference. Default: False. + """ + + def __init__( + self, + scale, + model_path, + model=None, + tile=0, + tile_pad=10, + pre_pad=10, + half=False, + device=None, + gpu_id=None, + ): + self.scale = scale + self.tile_size = tile + self.tile_pad = tile_pad + self.pre_pad = pre_pad + self.mod_scale = None + self.half = half + + # initialize model + if gpu_id: + self.device = ( + torch.device(f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu") if device is None else device + ) + else: + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device + # if the model_path starts with https, it will first download models to the folder: realesrgan/weights + if model_path.startswith("https://"): + model_path = load_file_from_url( + url=model_path, + model_dir=os.path.join("weights/realesrgan"), + progress=True, + file_name=None, + ) + loadnet = torch.load(model_path, map_location=torch.device("cpu")) + # prefer to use params_ema + if "params_ema" in loadnet: + keyname = "params_ema" + else: + keyname = "params" + model.load_state_dict(loadnet[keyname], strict=True) + model.eval() + self.model = model.to(self.device) + if self.half: + self.model = self.model.half() + + def pre_process(self, img): + """Pre-process, such as pre-pad and mod pad, so that the images can be divisible""" + img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float() + self.img = img.unsqueeze(0).to(self.device) + if self.half: + self.img = self.img.half() + + # pre_pad + self.img_pre_pad = self.img.clone() + if self.pre_pad != 0: + self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), "reflect") + # mod pad for divisible borders + if self.scale == 2: + self.mod_scale = 2 + elif self.scale == 1: + self.mod_scale = 4 + if self.mod_scale is not None: + self.mod_pad_h, self.mod_pad_w = 0, 0 + _, _, h, w = self.img.size() + if h % self.mod_scale != 0: + self.mod_pad_h = self.mod_scale - h % self.mod_scale + if w % self.mod_scale != 0: + self.mod_pad_w = self.mod_scale - w % self.mod_scale + self.img = F.pad(self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), "reflect") + + def process(self): + # model inference + self.output = self.model(self.img) + + def tile_process(self): + """It will first crop input images to tiles, and then process each tile. + Finally, all the processed tiles are merged into one images. 
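For orientation, a self-contained sketch of the tiled path; the nearest-neighbour stand-in network and dummy checkpoint below are purely illustrative, not a real SR model:

    import os, tempfile
    import numpy as np
    import torch
    import torch.nn as nn

    class _Nearest4x(nn.Module):
        # stand-in for a real SR network: plain 4x nearest-neighbour upsampling
        def forward(self, x):
            return nn.functional.interpolate(x, scale_factor=4, mode="nearest")

    # RealESRGANer loads weights from disk, so write a dummy checkpoint first
    ckpt = os.path.join(tempfile.mkdtemp(), "dummy.pth")
    torch.save({"params": _Nearest4x().state_dict()}, ckpt)

    upsampler = RealESRGANer(scale=4, model_path=ckpt, model=_Nearest4x(),
                             tile=16, tile_pad=4, pre_pad=0)
    lowres = (np.random.rand(40, 40, 3) * 255).astype(np.uint8)  # HWC, BGR, uint8
    highres, img_mode = upsampler.enhance(lowres)                # processed tile by tile
    # highres is a 160x160x3 uint8 BGR array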
+ + Modified from: https://github.com/ata4/esrgan-launcher + """ + batch, channel, height, width = self.img.shape + output_height = height * self.scale + output_width = width * self.scale + output_shape = (batch, channel, output_height, output_width) + + # start with black image + self.output = self.img.new_zeros(output_shape) + tiles_x = math.ceil(width / self.tile_size) + tiles_y = math.ceil(height / self.tile_size) + + # loop over all tiles + for y in range(tiles_y): + for x in range(tiles_x): + # extract tile from input image + ofs_x = x * self.tile_size + ofs_y = y * self.tile_size + # input tile area on total image + input_start_x = ofs_x + input_end_x = min(ofs_x + self.tile_size, width) + input_start_y = ofs_y + input_end_y = min(ofs_y + self.tile_size, height) + + # input tile area on total image with padding + input_start_x_pad = max(input_start_x - self.tile_pad, 0) + input_end_x_pad = min(input_end_x + self.tile_pad, width) + input_start_y_pad = max(input_start_y - self.tile_pad, 0) + input_end_y_pad = min(input_end_y + self.tile_pad, height) + + # input tile dimensions + input_tile_width = input_end_x - input_start_x + input_tile_height = input_end_y - input_start_y + tile_idx = y * tiles_x + x + 1 + input_tile = self.img[:, :, input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad] + + # upscale tile + try: + with torch.no_grad(): + output_tile = self.model(input_tile) + except RuntimeError as error: + print("Error", error) + # print(f'\tTile {tile_idx}/{tiles_x * tiles_y}') + + # output tile area on total image + output_start_x = input_start_x * self.scale + output_end_x = input_end_x * self.scale + output_start_y = input_start_y * self.scale + output_end_y = input_end_y * self.scale + + # output tile area without padding + output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale + output_end_x_tile = output_start_x_tile + input_tile_width * self.scale + output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale + output_end_y_tile = output_start_y_tile + input_tile_height * self.scale + + # put tile into output image + self.output[:, :, output_start_y:output_end_y, output_start_x:output_end_x] = output_tile[ + :, + :, + output_start_y_tile:output_end_y_tile, + output_start_x_tile:output_end_x_tile, + ] + + def post_process(self): + # remove extra pad + if self.mod_scale is not None: + _, _, h, w = self.output.size() + self.output = self.output[:, :, 0 : h - self.mod_pad_h * self.scale, 0 : w - self.mod_pad_w * self.scale] + # remove prepad + if self.pre_pad != 0: + _, _, h, w = self.output.size() + self.output = self.output[:, :, 0 : h - self.pre_pad * self.scale, 0 : w - self.pre_pad * self.scale] + return self.output + + @torch.no_grad() + def enhance(self, img, outscale=None, alpha_upsampler="realesrgan"): + h_input, w_input = img.shape[0:2] + # img: numpy + img = img.astype(np.float32) + if np.max(img) > 256: # 16-bit image + max_range = 65535 + print("\tInput is a 16-bit image") + else: + max_range = 255 + img = img / max_range + if len(img.shape) == 2: # gray image + img_mode = "L" + img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) + elif img.shape[2] == 4: # RGBA image with alpha channel + img_mode = "RGBA" + alpha = img[:, :, 3] + img = img[:, :, 0:3] + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + if alpha_upsampler == "realesrgan": + alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB) + else: + img_mode = "RGB" + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + + # ------------------- process image (without the alpha channel) 
------------------- # + try: + with torch.no_grad(): + self.pre_process(img) + if self.tile_size > 0: + self.tile_process() + else: + self.process() + output_img_t = self.post_process() + output_img = output_img_t.data.squeeze().float().cpu().clamp_(0, 1).numpy() + output_img = np.transpose(output_img[[2, 1, 0], :, :], (1, 2, 0)) + if img_mode == "L": + output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY) + del output_img_t + torch.cuda.empty_cache() + except RuntimeError as error: + output_img = cv2.resize( + self.img_pre_pad, + (w_input * self.scale, h_input * self.scale), + interpolation=cv2.INTER_LINEAR, + ) + print(f"Failed inference for RealESRGAN: {error}") + + # ------------------- process the alpha channel if necessary ------------------- # + if img_mode == "RGBA": + if alpha_upsampler == "realesrgan": + self.pre_process(alpha) + if self.tile_size > 0: + self.tile_process() + else: + self.process() + output_alpha = self.post_process() + output_alpha = output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy() + output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0)) + output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY) + else: # use the cv2 resize for alpha channel + h, w = alpha.shape[0:2] + output_alpha = cv2.resize(alpha, (w * self.scale, h * self.scale), interpolation=cv2.INTER_LINEAR) + + # merge the alpha channel + output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA) + output_img[:, :, 3] = output_alpha + + # ------------------------------ return ------------------------------ # + if max_range == 65535: # 16-bit image + output = (output_img * 65535.0).round().astype(np.uint16) + else: + output = (output_img * 255.0).round().astype(np.uint8) + + if outscale is not None and outscale != float(self.scale): + output = cv2.resize( + output, + ( + int(w_input * outscale), + int(h_input * outscale), + ), + interpolation=cv2.INTER_LANCZOS4, + ) + + return output, img_mode + + +class PrefetchReader(threading.Thread): + """Prefetch images. + + Args: + img_list (list[str]): A image list of image paths to be read. + num_prefetch_queue (int): Number of prefetch queue. + """ + + def __init__(self, img_list, num_prefetch_queue): + super().__init__() + self.que = queue.Queue(num_prefetch_queue) + self.img_list = img_list + + def run(self): + for img_path in self.img_list: + img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED) + self.que.put(img) + + self.que.put(None) + + def __next__(self): + next_item = self.que.get() + if next_item is None: + raise StopIteration + return next_item + + def __iter__(self): + return self + + +class IOConsumer(threading.Thread): + + def __init__(self, opt, que, qid): + super().__init__() + self._queue = que + self.qid = qid + self.opt = opt + + def run(self): + while True: + msg = self._queue.get() + if isinstance(msg, str) and msg == "quit": + break + + output = msg["output"] + save_path = msg["save_path"] + cv2.imwrite(save_path, output) + print(f"IO worker {self.qid} is done.") diff --git a/hordelib/nodes/facerestore_cf/basicsr/utils/registry.py b/hordelib/nodes/facerestore_cf/basicsr/utils/registry.py new file mode 100644 index 00000000..1eec4a97 --- /dev/null +++ b/hordelib/nodes/facerestore_cf/basicsr/utils/registry.py @@ -0,0 +1,83 @@ +# Modified from: https://github.com/facebookresearch/fvcore/blob/master/fvcore/common/registry.py + + +class Registry: + """ + The registry that provides name -> object mapping, to support third-party + users' custom modules. + + To create a registry (e.g. a backbone registry): + + .. 
code-block:: python + + BACKBONE_REGISTRY = Registry('BACKBONE') + + To register an object: + + .. code-block:: python + + @BACKBONE_REGISTRY.register() + class MyBackbone(): + ... + + Or: + + .. code-block:: python + + BACKBONE_REGISTRY.register(MyBackbone) + """ + + def __init__(self, name): + """ + Args: + name (str): the name of this registry + """ + self._name = name + self._obj_map = {} + + def _do_register(self, name, obj): + assert name not in self._obj_map, ( + f"An object named '{name}' was already registered " f"in '{self._name}' registry!" + ) + self._obj_map[name] = obj + + def register(self, obj=None): + """ + Register the given object under the the name `obj.__name__`. + Can be used as either a decorator or not. + See docstring of this class for usage. + """ + if obj is None: + # used as a decorator + def deco(func_or_class): + name = func_or_class.__name__ + self._do_register(name, func_or_class) + return func_or_class + + return deco + + # used as a function call + name = obj.__name__ + self._do_register(name, obj) + + def get(self, name): + ret = self._obj_map.get(name) + if ret is None: + raise KeyError(f"No object named '{name}' found in '{self._name}' registry!") + return ret + + def __contains__(self, name): + return name in self._obj_map + + def __iter__(self): + return iter(self._obj_map.items()) + + def keys(self): + return self._obj_map.keys() + + +DATASET_REGISTRY = Registry("dataset") +ARCH_REGISTRY = Registry("arch") +MODEL_REGISTRY = Registry("model") +LOSS_REGISTRY = Registry("loss") +METRIC_REGISTRY = Registry("metric") diff --git a/hordelib/nodes/facerestore_cf/basicsr/version.py b/hordelib/nodes/facerestore_cf/basicsr/version.py new file mode 100644 index 00000000..677c4699 --- /dev/null +++ b/hordelib/nodes/facerestore_cf/basicsr/version.py @@ -0,0 +1,5 @@ +# GENERATED VERSION FILE +# TIME: Sun Aug 7 15:14:26 2022 +__version__ = "1.3.2" +__gitsha__ = "6f94023" +version_info = (1, 3, 2) diff --git a/hordelib/nodes/facerestore/facelib/detection/yolov5face/__init__.py b/hordelib/nodes/facerestore_cf/facelib/__init__.py similarity index 100% rename from hordelib/nodes/facerestore/facelib/detection/yolov5face/__init__.py rename to hordelib/nodes/facerestore_cf/facelib/__init__.py diff --git a/hordelib/nodes/facerestore/facelib/detection/__init__.py b/hordelib/nodes/facerestore_cf/facelib/detection/__init__.py similarity index 83% rename from hordelib/nodes/facerestore/facelib/detection/__init__.py rename to hordelib/nodes/facerestore_cf/facelib/detection/__init__.py index 7ec791bf..9bf408a8 100644 --- a/hordelib/nodes/facerestore/facelib/detection/__init__.py +++ b/hordelib/nodes/facerestore_cf/facelib/detection/__init__.py @@ -1,116 +1,116 @@ -import os -import torch -from torch import nn -from copy import deepcopy - -from hordelib.nodes.facerestore.facelib.utils import load_file_from_url -from hordelib.nodes.facerestore.facelib.utils import download_pretrained_models -from hordelib.nodes.facerestore.facelib.detection.yolov5face.models.common import Conv - -from .retinaface.retinaface import RetinaFace -from .yolov5face.face_detector import YoloDetector - - -def init_detection_model(model_name, half=False, device="cuda"): - if "retinaface" in model_name: - model = init_retinaface_model(model_name, half, device) - elif "YOLOv5" in model_name: - model = init_yolov5face_model(model_name, device) - else: - raise NotImplementedError(f"{model_name} is not implemented.") - - return model - - -def init_retinaface_model(model_name, half=False, device="cuda"): - if 
model_name == "retinaface_resnet50": - model = RetinaFace(network_name="resnet50", half=half) - model_url = "https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_Resnet50_Final.pth" - elif model_name == "retinaface_mobile0.25": - model = RetinaFace(network_name="mobile0.25", half=half) - model_url = "https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_mobilenet0.25_Final.pth" - else: - raise NotImplementedError(f"{model_name} is not implemented.") - - model_path = load_file_from_url( - url=model_url, - model_dir="../../models/facedetection", - progress=True, - file_name=None, - ) - load_net = torch.load(model_path, map_location=lambda storage, loc: storage) - # remove unnecessary 'module.' - for k, v in deepcopy(load_net).items(): - if k.startswith("module."): - load_net[k[7:]] = v - load_net.pop(k) - model.load_state_dict(load_net, strict=True) - model.eval() - model = model.to(device) - - return model - - -def init_yolov5face_model(model_name, device="cuda"): - if model_name == "YOLOv5l": - model = YoloDetector( - config_name="nodes/facerestore/facelib/detection/yolov5face/models/yolov5l.yaml", - device=device, - ) - model_url = "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/yolov5l-face.pth" - elif model_name == "YOLOv5n": - model = YoloDetector( - config_name="nodes/facerestore/facelib/detection/yolov5face/models/yolov5n.yaml", - device=device, - ) - model_url = "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/yolov5n-face.pth" - else: - raise NotImplementedError(f"{model_name} is not implemented.") - - model_path = load_file_from_url( - url=model_url, - model_dir="../../models/facedetection", - progress=True, - file_name=None, - ) - load_net = torch.load(model_path, map_location=lambda storage, loc: storage) - model.detector.load_state_dict(load_net, strict=True) - model.detector.eval() - model.detector = model.detector.to(device).float() - - for m in model.detector.modules(): - if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]: - m.inplace = True # pytorch 1.7.0 compatibility - elif isinstance(m, Conv): - m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility - - return model - - -# Download from Google Drive -# def init_yolov5face_model(model_name, device='cuda'): -# if model_name == 'YOLOv5l': -# model = YoloDetector(config_name='facelib/detection/yolov5face/models/yolov5l.yaml', device=device) -# f_id = {'yolov5l-face.pth': '131578zMA6B2x8VQHyHfa6GEPtulMCNzV'} -# elif model_name == 'YOLOv5n': -# model = YoloDetector(config_name='facelib/detection/yolov5face/models/yolov5n.yaml', device=device) -# f_id = {'yolov5n-face.pth': '1fhcpFvWZqghpGXjYPIne2sw1Fy4yhw6o'} -# else: -# raise NotImplementedError(f'{model_name} is not implemented.') - -# model_path = os.path.join('../../models/facedetection', list(f_id.keys())[0]) -# if not os.path.exists(model_path): -# download_pretrained_models(file_ids=f_id, save_path_root='../../models/facedetection') - -# load_net = torch.load(model_path, map_location=lambda storage, loc: storage) -# model.detector.load_state_dict(load_net, strict=True) -# model.detector.eval() -# model.detector = model.detector.to(device).float() - -# for m in model.detector.modules(): -# if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]: -# m.inplace = True # pytorch 1.7.0 compatibility -# elif isinstance(m, Conv): -# m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility - -# return model +import os +import pathlib +from copy import 
deepcopy + +import torch +from torch import nn + +from hordelib.nodes.facerestore_cf.facelib.detection.yolov5face.models.common import Conv +from hordelib.nodes.facerestore_cf.facelib.utils import download_pretrained_models, load_file_from_url + +from .retinaface.retinaface import RetinaFace +from .yolov5face.face_detector import YoloDetector + + +def init_detection_model(model_name, half=False, device="cuda"): + if "retinaface" in model_name: + model = init_retinaface_model(model_name, half, device) + elif "YOLOv5" in model_name: + model = init_yolov5face_model(model_name, device) + else: + raise NotImplementedError(f"{model_name} is not implemented.") + + return model + + +def init_retinaface_model(model_name, half=False, device="cuda"): + if model_name == "retinaface_resnet50": + model = RetinaFace(network_name="resnet50", half=half) + model_url = "https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_Resnet50_Final.pth" + filename = "detection_Resnet50_Final.pth" + elif model_name == "retinaface_mobile0.25": + model = RetinaFace(network_name="mobile0.25", half=half) + model_url = "https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_mobilenet0.25_Final.pth" + filename = "detection_mobilenet0.25_Final.pth" + else: + raise NotImplementedError(f"{model_name} is not implemented.") + + model_path = load_file_from_url( + url=model_url, + model_dir="../../models/facedetection", + progress=True, + file_name=filename, + ) + load_net = torch.load(model_path, map_location=lambda storage, loc: storage) + # remove unnecessary 'module.' + for k, v in deepcopy(load_net).items(): + if k.startswith("module."): + load_net[k[7:]] = v + load_net.pop(k) + model.load_state_dict(load_net, strict=True) + model.eval() + model = model.to(device) + + return model + + +def init_yolov5face_model(model_name, device="cuda"): + current_dir = str(pathlib.Path(__file__).parent.resolve()) + if model_name == "YOLOv5l": + model = YoloDetector(config_name=current_dir + "/yolov5face/models/yolov5l.yaml", device=device) + model_url = "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/yolov5l-face.pth" + filename = "yolov5l-face.pth" + elif model_name == "YOLOv5n": + model = YoloDetector(config_name=current_dir + "/yolov5face/models/yolov5n.yaml", device=device) + model_url = "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/yolov5n-face.pth" + filename = "yolov5n-face.pth" + else: + raise NotImplementedError(f"{model_name} is not implemented.") + + model_path = load_file_from_url( + url=model_url, + model_dir="../../models/facedetection", + progress=True, + file_name=filename, + ) + load_net = torch.load(model_path, map_location=lambda storage, loc: storage) + model.detector.load_state_dict(load_net, strict=True) + model.detector.eval() + model.detector = model.detector.to(device).float() + + for m in model.detector.modules(): + if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]: + m.inplace = True # pytorch 1.7.0 compatibility + elif isinstance(m, Conv): + m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility + + return model + + +# Download from Google Drive +# def init_yolov5face_model(model_name, device='cuda'): +# if model_name == 'YOLOv5l': +# model = YoloDetector(config_name='facelib/detection/yolov5face/models/yolov5l.yaml', device=device) +# f_id = {'yolov5l-face.pth': '131578zMA6B2x8VQHyHfa6GEPtulMCNzV'} +# elif model_name == 'YOLOv5n': +# model = YoloDetector(config_name='facelib/detection/yolov5face/models/yolov5n.yaml', 
device=device) +# f_id = {'yolov5n-face.pth': '1fhcpFvWZqghpGXjYPIne2sw1Fy4yhw6o'} +# else: +# raise NotImplementedError(f'{model_name} is not implemented.') + +# model_path = os.path.join('../../models/facedetection', list(f_id.keys())[0]) +# if not os.path.exists(model_path): +# download_pretrained_models(file_ids=f_id, save_path_root='../../models/facedetection') + +# load_net = torch.load(model_path, map_location=lambda storage, loc: storage) +# model.detector.load_state_dict(load_net, strict=True) +# model.detector.eval() +# model.detector = model.detector.to(device).float() + +# for m in model.detector.modules(): +# if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]: +# m.inplace = True # pytorch 1.7.0 compatibility +# elif isinstance(m, Conv): +# m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility + +# return model diff --git a/hordelib/nodes/facerestore/facelib/detection/align_trans.py b/hordelib/nodes/facerestore_cf/facelib/detection/align_trans.py similarity index 75% rename from hordelib/nodes/facerestore/facelib/detection/align_trans.py rename to hordelib/nodes/facerestore_cf/facelib/detection/align_trans.py index 07f1eb36..84e7a16d 100644 --- a/hordelib/nodes/facerestore/facelib/detection/align_trans.py +++ b/hordelib/nodes/facerestore_cf/facelib/detection/align_trans.py @@ -1,219 +1,233 @@ -import cv2 -import numpy as np - -from .matlab_cp2tform import get_similarity_transform_for_cv2 - -# reference facial points, a list of coordinates (x,y) -REFERENCE_FACIAL_POINTS = [[30.29459953, 51.69630051], [65.53179932, 51.50139999], [48.02519989, 71.73660278], - [33.54930115, 92.3655014], [62.72990036, 92.20410156]] - -DEFAULT_CROP_SIZE = (96, 112) - - -class FaceWarpException(Exception): - - def __str__(self): - return 'In File {}:{}'.format(__file__, super.__str__(self)) - - -def get_reference_facial_points(output_size=None, inner_padding_factor=0.0, outer_padding=(0, 0), default_square=False): - """ - Function: - ---------- - get reference 5 key points according to crop settings: - 0. Set default crop_size: - if default_square: - crop_size = (112, 112) - else: - crop_size = (96, 112) - 1. Pad the crop_size by inner_padding_factor in each side; - 2. Resize crop_size into (output_size - outer_padding*2), - pad into output_size with outer_padding; - 3. Output reference_5point; - Parameters: - ---------- - @output_size: (w, h) or None - size of aligned face image - @inner_padding_factor: (w_factor, h_factor) - padding factor for inner (w, h) - @outer_padding: (w_pad, h_pad) - each row is a pair of coordinates (x, y) - @default_square: True or False - if True: - default crop_size = (112, 112) - else: - default crop_size = (96, 112); - !!! 
make sure, if output_size is not None: - (output_size - outer_padding) - = some_scale * (default crop_size * (1.0 + - inner_padding_factor)) - Returns: - ---------- - @reference_5point: 5x2 np.array - each row is a pair of transformed coordinates (x, y) - """ - - tmp_5pts = np.array(REFERENCE_FACIAL_POINTS) - tmp_crop_size = np.array(DEFAULT_CROP_SIZE) - - # 0) make the inner region a square - if default_square: - size_diff = max(tmp_crop_size) - tmp_crop_size - tmp_5pts += size_diff / 2 - tmp_crop_size += size_diff - - if (output_size and output_size[0] == tmp_crop_size[0] and output_size[1] == tmp_crop_size[1]): - - return tmp_5pts - - if (inner_padding_factor == 0 and outer_padding == (0, 0)): - if output_size is None: - return tmp_5pts - else: - raise FaceWarpException('No paddings to do, output_size must be None or {}'.format(tmp_crop_size)) - - # check output size - if not (0 <= inner_padding_factor <= 1.0): - raise FaceWarpException('Not (0 <= inner_padding_factor <= 1.0)') - - if ((inner_padding_factor > 0 or outer_padding[0] > 0 or outer_padding[1] > 0) and output_size is None): - output_size = tmp_crop_size * \ - (1 + inner_padding_factor * 2).astype(np.int32) - output_size += np.array(outer_padding) - if not (outer_padding[0] < output_size[0] and outer_padding[1] < output_size[1]): - raise FaceWarpException('Not (outer_padding[0] < output_size[0] and outer_padding[1] < output_size[1])') - - # 1) pad the inner region according inner_padding_factor - if inner_padding_factor > 0: - size_diff = tmp_crop_size * inner_padding_factor * 2 - tmp_5pts += size_diff / 2 - tmp_crop_size += np.round(size_diff).astype(np.int32) - - # 2) resize the padded inner region - size_bf_outer_pad = np.array(output_size) - np.array(outer_padding) * 2 - - if size_bf_outer_pad[0] * tmp_crop_size[1] != size_bf_outer_pad[1] * tmp_crop_size[0]: - raise FaceWarpException('Must have (output_size - outer_padding)' - '= some_scale * (crop_size * (1.0 + inner_padding_factor)') - - scale_factor = size_bf_outer_pad[0].astype(np.float32) / tmp_crop_size[0] - tmp_5pts = tmp_5pts * scale_factor - # size_diff = tmp_crop_size * (scale_factor - min(scale_factor)) - # tmp_5pts = tmp_5pts + size_diff / 2 - tmp_crop_size = size_bf_outer_pad - - # 3) add outer_padding to make output_size - reference_5point = tmp_5pts + np.array(outer_padding) - tmp_crop_size = output_size - - return reference_5point - - -def get_affine_transform_matrix(src_pts, dst_pts): - """ - Function: - ---------- - get affine transform matrix 'tfm' from src_pts to dst_pts - Parameters: - ---------- - @src_pts: Kx2 np.array - source points matrix, each row is a pair of coordinates (x, y) - @dst_pts: Kx2 np.array - destination points matrix, each row is a pair of coordinates (x, y) - Returns: - ---------- - @tfm: 2x3 np.array - transform matrix from src_pts to dst_pts - """ - - tfm = np.float32([[1, 0, 0], [0, 1, 0]]) - n_pts = src_pts.shape[0] - ones = np.ones((n_pts, 1), src_pts.dtype) - src_pts_ = np.hstack([src_pts, ones]) - dst_pts_ = np.hstack([dst_pts, ones]) - - A, res, rank, s = np.linalg.lstsq(src_pts_, dst_pts_) - - if rank == 3: - tfm = np.float32([[A[0, 0], A[1, 0], A[2, 0]], [A[0, 1], A[1, 1], A[2, 1]]]) - elif rank == 2: - tfm = np.float32([[A[0, 0], A[1, 0], 0], [A[0, 1], A[1, 1], 0]]) - - return tfm - - -def warp_and_crop_face(src_img, facial_pts, reference_pts=None, crop_size=(96, 112), align_type='smilarity'): - """ - Function: - ---------- - apply affine transform 'trans' to uv - Parameters: - ---------- - @src_img: 3x3 np.array - input 
image - @facial_pts: could be - 1)a list of K coordinates (x,y) - or - 2) Kx2 or 2xK np.array - each row or col is a pair of coordinates (x, y) - @reference_pts: could be - 1) a list of K coordinates (x,y) - or - 2) Kx2 or 2xK np.array - each row or col is a pair of coordinates (x, y) - or - 3) None - if None, use default reference facial points - @crop_size: (w, h) - output face image size - @align_type: transform type, could be one of - 1) 'similarity': use similarity transform - 2) 'cv2_affine': use the first 3 points to do affine transform, - by calling cv2.getAffineTransform() - 3) 'affine': use all points to do affine transform - Returns: - ---------- - @face_img: output face image with size (w, h) = @crop_size - """ - - if reference_pts is None: - if crop_size[0] == 96 and crop_size[1] == 112: - reference_pts = REFERENCE_FACIAL_POINTS - else: - default_square = False - inner_padding_factor = 0 - outer_padding = (0, 0) - output_size = crop_size - - reference_pts = get_reference_facial_points(output_size, inner_padding_factor, outer_padding, - default_square) - - ref_pts = np.float32(reference_pts) - ref_pts_shp = ref_pts.shape - if max(ref_pts_shp) < 3 or min(ref_pts_shp) != 2: - raise FaceWarpException('reference_pts.shape must be (K,2) or (2,K) and K>2') - - if ref_pts_shp[0] == 2: - ref_pts = ref_pts.T - - src_pts = np.float32(facial_pts) - src_pts_shp = src_pts.shape - if max(src_pts_shp) < 3 or min(src_pts_shp) != 2: - raise FaceWarpException('facial_pts.shape must be (K,2) or (2,K) and K>2') - - if src_pts_shp[0] == 2: - src_pts = src_pts.T - - if src_pts.shape != ref_pts.shape: - raise FaceWarpException('facial_pts and reference_pts must have the same shape') - - if align_type == 'cv2_affine': - tfm = cv2.getAffineTransform(src_pts[0:3], ref_pts[0:3]) - elif align_type == 'affine': - tfm = get_affine_transform_matrix(src_pts, ref_pts) - else: - tfm = get_similarity_transform_for_cv2(src_pts, ref_pts) - - face_img = cv2.warpAffine(src_img, tfm, (crop_size[0], crop_size[1])) - - return face_img +import cv2 +import numpy as np + +from .matlab_cp2tform import get_similarity_transform_for_cv2 + +# reference facial points, a list of coordinates (x,y) +REFERENCE_FACIAL_POINTS = [ + [30.29459953, 51.69630051], + [65.53179932, 51.50139999], + [48.02519989, 71.73660278], + [33.54930115, 92.3655014], + [62.72990036, 92.20410156], +] + +DEFAULT_CROP_SIZE = (96, 112) + + +class FaceWarpException(Exception): + + def __str__(self): + return f"In File {__file__}:{super.__str__(self)}" + + +def get_reference_facial_points( + output_size=None, + inner_padding_factor=0.0, + outer_padding=(0, 0), + default_square=False, +): + """ + Function: + ---------- + get reference 5 key points according to crop settings: + 0. Set default crop_size: + if default_square: + crop_size = (112, 112) + else: + crop_size = (96, 112) + 1. Pad the crop_size by inner_padding_factor in each side; + 2. Resize crop_size into (output_size - outer_padding*2), + pad into output_size with outer_padding; + 3. Output reference_5point; + Parameters: + ---------- + @output_size: (w, h) or None + size of aligned face image + @inner_padding_factor: (w_factor, h_factor) + padding factor for inner (w, h) + @outer_padding: (w_pad, h_pad) + each row is a pair of coordinates (x, y) + @default_square: True or False + if True: + default crop_size = (112, 112) + else: + default crop_size = (96, 112); + !!! 
make sure, if output_size is not None: + (output_size - outer_padding) + = some_scale * (default crop_size * (1.0 + + inner_padding_factor)) + Returns: + ---------- + @reference_5point: 5x2 np.array + each row is a pair of transformed coordinates (x, y) + """ + + tmp_5pts = np.array(REFERENCE_FACIAL_POINTS) + tmp_crop_size = np.array(DEFAULT_CROP_SIZE) + + # 0) make the inner region a square + if default_square: + size_diff = max(tmp_crop_size) - tmp_crop_size + tmp_5pts += size_diff / 2 + tmp_crop_size += size_diff + + if output_size and output_size[0] == tmp_crop_size[0] and output_size[1] == tmp_crop_size[1]: + + return tmp_5pts + + if inner_padding_factor == 0 and outer_padding == (0, 0): + if output_size is None: + return tmp_5pts + else: + raise FaceWarpException(f"No paddings to do, output_size must be None or {tmp_crop_size}") + + # check output size + if not (0 <= inner_padding_factor <= 1.0): + raise FaceWarpException("Not (0 <= inner_padding_factor <= 1.0)") + + if (inner_padding_factor > 0 or outer_padding[0] > 0 or outer_padding[1] > 0) and output_size is None: + output_size = tmp_crop_size * (1 + inner_padding_factor * 2).astype(np.int32) + output_size += np.array(outer_padding) + if not (outer_padding[0] < output_size[0] and outer_padding[1] < output_size[1]): + raise FaceWarpException("Not (outer_padding[0] < output_size[0] and outer_padding[1] < output_size[1])") + + # 1) pad the inner region according inner_padding_factor + if inner_padding_factor > 0: + size_diff = tmp_crop_size * inner_padding_factor * 2 + tmp_5pts += size_diff / 2 + tmp_crop_size += np.round(size_diff).astype(np.int32) + + # 2) resize the padded inner region + size_bf_outer_pad = np.array(output_size) - np.array(outer_padding) * 2 + + if size_bf_outer_pad[0] * tmp_crop_size[1] != size_bf_outer_pad[1] * tmp_crop_size[0]: + raise FaceWarpException( + "Must have (output_size - outer_padding)" "= some_scale * (crop_size * (1.0 + inner_padding_factor)", + ) + + scale_factor = size_bf_outer_pad[0].astype(np.float32) / tmp_crop_size[0] + tmp_5pts = tmp_5pts * scale_factor + # size_diff = tmp_crop_size * (scale_factor - min(scale_factor)) + # tmp_5pts = tmp_5pts + size_diff / 2 + tmp_crop_size = size_bf_outer_pad + + # 3) add outer_padding to make output_size + reference_5point = tmp_5pts + np.array(outer_padding) + tmp_crop_size = output_size + + return reference_5point + + +def get_affine_transform_matrix(src_pts, dst_pts): + """ + Function: + ---------- + get affine transform matrix 'tfm' from src_pts to dst_pts + Parameters: + ---------- + @src_pts: Kx2 np.array + source points matrix, each row is a pair of coordinates (x, y) + @dst_pts: Kx2 np.array + destination points matrix, each row is a pair of coordinates (x, y) + Returns: + ---------- + @tfm: 2x3 np.array + transform matrix from src_pts to dst_pts + """ + + tfm = np.float32([[1, 0, 0], [0, 1, 0]]) + n_pts = src_pts.shape[0] + ones = np.ones((n_pts, 1), src_pts.dtype) + src_pts_ = np.hstack([src_pts, ones]) + dst_pts_ = np.hstack([dst_pts, ones]) + + A, res, rank, s = np.linalg.lstsq(src_pts_, dst_pts_) + + if rank == 3: + tfm = np.float32([[A[0, 0], A[1, 0], A[2, 0]], [A[0, 1], A[1, 1], A[2, 1]]]) + elif rank == 2: + tfm = np.float32([[A[0, 0], A[1, 0], 0], [A[0, 1], A[1, 1], 0]]) + + return tfm + + +def warp_and_crop_face(src_img, facial_pts, reference_pts=None, crop_size=(96, 112), align_type="smilarity"): + """ + Function: + ---------- + apply affine transform 'trans' to uv + Parameters: + ---------- + @src_img: 3x3 np.array + input image + 
@facial_pts: could be + 1)a list of K coordinates (x,y) + or + 2) Kx2 or 2xK np.array + each row or col is a pair of coordinates (x, y) + @reference_pts: could be + 1) a list of K coordinates (x,y) + or + 2) Kx2 or 2xK np.array + each row or col is a pair of coordinates (x, y) + or + 3) None + if None, use default reference facial points + @crop_size: (w, h) + output face image size + @align_type: transform type, could be one of + 1) 'similarity': use similarity transform + 2) 'cv2_affine': use the first 3 points to do affine transform, + by calling cv2.getAffineTransform() + 3) 'affine': use all points to do affine transform + Returns: + ---------- + @face_img: output face image with size (w, h) = @crop_size + """ + + if reference_pts is None: + if crop_size[0] == 96 and crop_size[1] == 112: + reference_pts = REFERENCE_FACIAL_POINTS + else: + default_square = False + inner_padding_factor = 0 + outer_padding = (0, 0) + output_size = crop_size + + reference_pts = get_reference_facial_points( + output_size, + inner_padding_factor, + outer_padding, + default_square, + ) + + ref_pts = np.float32(reference_pts) + ref_pts_shp = ref_pts.shape + if max(ref_pts_shp) < 3 or min(ref_pts_shp) != 2: + raise FaceWarpException("reference_pts.shape must be (K,2) or (2,K) and K>2") + + if ref_pts_shp[0] == 2: + ref_pts = ref_pts.T + + src_pts = np.float32(facial_pts) + src_pts_shp = src_pts.shape + if max(src_pts_shp) < 3 or min(src_pts_shp) != 2: + raise FaceWarpException("facial_pts.shape must be (K,2) or (2,K) and K>2") + + if src_pts_shp[0] == 2: + src_pts = src_pts.T + + if src_pts.shape != ref_pts.shape: + raise FaceWarpException("facial_pts and reference_pts must have the same shape") + + if align_type == "cv2_affine": + tfm = cv2.getAffineTransform(src_pts[0:3], ref_pts[0:3]) + elif align_type == "affine": + tfm = get_affine_transform_matrix(src_pts, ref_pts) + else: + tfm = get_similarity_transform_for_cv2(src_pts, ref_pts) + + face_img = cv2.warpAffine(src_img, tfm, (crop_size[0], crop_size[1])) + + return face_img diff --git a/hordelib/nodes/facerestore/facelib/detection/matlab_cp2tform.py b/hordelib/nodes/facerestore_cf/facelib/detection/matlab_cp2tform.py similarity index 88% rename from hordelib/nodes/facerestore/facelib/detection/matlab_cp2tform.py rename to hordelib/nodes/facerestore_cf/facelib/detection/matlab_cp2tform.py index b2a8b54a..7bd7b4cd 100644 --- a/hordelib/nodes/facerestore/facelib/detection/matlab_cp2tform.py +++ b/hordelib/nodes/facerestore_cf/facelib/detection/matlab_cp2tform.py @@ -1,317 +1,316 @@ -import numpy as np -from numpy.linalg import inv, lstsq -from numpy.linalg import matrix_rank as rank -from numpy.linalg import norm - - -class MatlabCp2tormException(Exception): - - def __str__(self): - return 'In File {}:{}'.format(__file__, super.__str__(self)) - - -def tformfwd(trans, uv): - """ - Function: - ---------- - apply affine transform 'trans' to uv - - Parameters: - ---------- - @trans: 3x3 np.array - transform matrix - @uv: Kx2 np.array - each row is a pair of coordinates (x, y) - - Returns: - ---------- - @xy: Kx2 np.array - each row is a pair of transformed coordinates (x, y) - """ - uv = np.hstack((uv, np.ones((uv.shape[0], 1)))) - xy = np.dot(uv, trans) - xy = xy[:, 0:-1] - return xy - - -def tforminv(trans, uv): - """ - Function: - ---------- - apply the inverse of affine transform 'trans' to uv - - Parameters: - ---------- - @trans: 3x3 np.array - transform matrix - @uv: Kx2 np.array - each row is a pair of coordinates (x, y) - - Returns: - ---------- - @xy: 
Kx2 np.array - each row is a pair of inverse-transformed coordinates (x, y) - """ - Tinv = inv(trans) - xy = tformfwd(Tinv, uv) - return xy - - -def findNonreflectiveSimilarity(uv, xy, options=None): - options = {'K': 2} - - K = options['K'] - M = xy.shape[0] - x = xy[:, 0].reshape((-1, 1)) # use reshape to keep a column vector - y = xy[:, 1].reshape((-1, 1)) # use reshape to keep a column vector - - tmp1 = np.hstack((x, y, np.ones((M, 1)), np.zeros((M, 1)))) - tmp2 = np.hstack((y, -x, np.zeros((M, 1)), np.ones((M, 1)))) - X = np.vstack((tmp1, tmp2)) - - u = uv[:, 0].reshape((-1, 1)) # use reshape to keep a column vector - v = uv[:, 1].reshape((-1, 1)) # use reshape to keep a column vector - U = np.vstack((u, v)) - - # We know that X * r = U - if rank(X) >= 2 * K: - r, _, _, _ = lstsq(X, U, rcond=-1) - r = np.squeeze(r) - else: - raise Exception('cp2tform:twoUniquePointsReq') - sc = r[0] - ss = r[1] - tx = r[2] - ty = r[3] - - Tinv = np.array([[sc, -ss, 0], [ss, sc, 0], [tx, ty, 1]]) - T = inv(Tinv) - T[:, 2] = np.array([0, 0, 1]) - - return T, Tinv - - -def findSimilarity(uv, xy, options=None): - options = {'K': 2} - - # uv = np.array(uv) - # xy = np.array(xy) - - # Solve for trans1 - trans1, trans1_inv = findNonreflectiveSimilarity(uv, xy, options) - - # Solve for trans2 - - # manually reflect the xy data across the Y-axis - xyR = xy - xyR[:, 0] = -1 * xyR[:, 0] - - trans2r, trans2r_inv = findNonreflectiveSimilarity(uv, xyR, options) - - # manually reflect the tform to undo the reflection done on xyR - TreflectY = np.array([[-1, 0, 0], [0, 1, 0], [0, 0, 1]]) - - trans2 = np.dot(trans2r, TreflectY) - - # Figure out if trans1 or trans2 is better - xy1 = tformfwd(trans1, uv) - norm1 = norm(xy1 - xy) - - xy2 = tformfwd(trans2, uv) - norm2 = norm(xy2 - xy) - - if norm1 <= norm2: - return trans1, trans1_inv - else: - trans2_inv = inv(trans2) - return trans2, trans2_inv - - -def get_similarity_transform(src_pts, dst_pts, reflective=True): - """ - Function: - ---------- - Find Similarity Transform Matrix 'trans': - u = src_pts[:, 0] - v = src_pts[:, 1] - x = dst_pts[:, 0] - y = dst_pts[:, 1] - [x, y, 1] = [u, v, 1] * trans - - Parameters: - ---------- - @src_pts: Kx2 np.array - source points, each row is a pair of coordinates (x, y) - @dst_pts: Kx2 np.array - destination points, each row is a pair of transformed - coordinates (x, y) - @reflective: True or False - if True: - use reflective similarity transform - else: - use non-reflective similarity transform - - Returns: - ---------- - @trans: 3x3 np.array - transform matrix from uv to xy - trans_inv: 3x3 np.array - inverse of trans, transform matrix from xy to uv - """ - - if reflective: - trans, trans_inv = findSimilarity(src_pts, dst_pts) - else: - trans, trans_inv = findNonreflectiveSimilarity(src_pts, dst_pts) - - return trans, trans_inv - - -def cvt_tform_mat_for_cv2(trans): - """ - Function: - ---------- - Convert Transform Matrix 'trans' into 'cv2_trans' which could be - directly used by cv2.warpAffine(): - u = src_pts[:, 0] - v = src_pts[:, 1] - x = dst_pts[:, 0] - y = dst_pts[:, 1] - [x, y].T = cv_trans * [u, v, 1].T - - Parameters: - ---------- - @trans: 3x3 np.array - transform matrix from uv to xy - - Returns: - ---------- - @cv2_trans: 2x3 np.array - transform matrix from src_pts to dst_pts, could be directly used - for cv2.warpAffine() - """ - cv2_trans = trans[:, 0:2].T - - return cv2_trans - - -def get_similarity_transform_for_cv2(src_pts, dst_pts, reflective=True): - """ - Function: - ---------- - Find Similarity Transform 
Matrix 'cv2_trans' which could be - directly used by cv2.warpAffine(): - u = src_pts[:, 0] - v = src_pts[:, 1] - x = dst_pts[:, 0] - y = dst_pts[:, 1] - [x, y].T = cv_trans * [u, v, 1].T - - Parameters: - ---------- - @src_pts: Kx2 np.array - source points, each row is a pair of coordinates (x, y) - @dst_pts: Kx2 np.array - destination points, each row is a pair of transformed - coordinates (x, y) - reflective: True or False - if True: - use reflective similarity transform - else: - use non-reflective similarity transform - - Returns: - ---------- - @cv2_trans: 2x3 np.array - transform matrix from src_pts to dst_pts, could be directly used - for cv2.warpAffine() - """ - trans, trans_inv = get_similarity_transform(src_pts, dst_pts, reflective) - cv2_trans = cvt_tform_mat_for_cv2(trans) - - return cv2_trans - - -if __name__ == '__main__': - """ - u = [0, 6, -2] - v = [0, 3, 5] - x = [-1, 0, 4] - y = [-1, -10, 4] - - # In Matlab, run: - # - # uv = [u'; v']; - # xy = [x'; y']; - # tform_sim=cp2tform(uv,xy,'similarity'); - # - # trans = tform_sim.tdata.T - # ans = - # -0.0764 -1.6190 0 - # 1.6190 -0.0764 0 - # -3.2156 0.0290 1.0000 - # trans_inv = tform_sim.tdata.Tinv - # ans = - # - # -0.0291 0.6163 0 - # -0.6163 -0.0291 0 - # -0.0756 1.9826 1.0000 - # xy_m=tformfwd(tform_sim, u,v) - # - # xy_m = - # - # -3.2156 0.0290 - # 1.1833 -9.9143 - # 5.0323 2.8853 - # uv_m=tforminv(tform_sim, x,y) - # - # uv_m = - # - # 0.5698 1.3953 - # 6.0872 2.2733 - # -2.6570 4.3314 - """ - u = [0, 6, -2] - v = [0, 3, 5] - x = [-1, 0, 4] - y = [-1, -10, 4] - - uv = np.array((u, v)).T - xy = np.array((x, y)).T - - print('\n--->uv:') - print(uv) - print('\n--->xy:') - print(xy) - - trans, trans_inv = get_similarity_transform(uv, xy) - - print('\n--->trans matrix:') - print(trans) - - print('\n--->trans_inv matrix:') - print(trans_inv) - - print('\n---> apply transform to uv') - print('\nxy_m = uv_augmented * trans') - uv_aug = np.hstack((uv, np.ones((uv.shape[0], 1)))) - xy_m = np.dot(uv_aug, trans) - print(xy_m) - - print('\nxy_m = tformfwd(trans, uv)') - xy_m = tformfwd(trans, uv) - print(xy_m) - - print('\n---> apply inverse transform to xy') - print('\nuv_m = xy_augmented * trans_inv') - xy_aug = np.hstack((xy, np.ones((xy.shape[0], 1)))) - uv_m = np.dot(xy_aug, trans_inv) - print(uv_m) - - print('\nuv_m = tformfwd(trans_inv, xy)') - uv_m = tformfwd(trans_inv, xy) - print(uv_m) - - uv_m = tforminv(trans, xy) - print('\nuv_m = tforminv(trans, xy)') - print(uv_m) +import numpy as np +from numpy.linalg import inv, lstsq, norm +from numpy.linalg import matrix_rank as rank + + +class MatlabCp2tormException(Exception): + + def __str__(self): + return f"In File {__file__}:{super.__str__(self)}" + + +def tformfwd(trans, uv): + """ + Function: + ---------- + apply affine transform 'trans' to uv + + Parameters: + ---------- + @trans: 3x3 np.array + transform matrix + @uv: Kx2 np.array + each row is a pair of coordinates (x, y) + + Returns: + ---------- + @xy: Kx2 np.array + each row is a pair of transformed coordinates (x, y) + """ + uv = np.hstack((uv, np.ones((uv.shape[0], 1)))) + xy = np.dot(uv, trans) + xy = xy[:, 0:-1] + return xy + + +def tforminv(trans, uv): + """ + Function: + ---------- + apply the inverse of affine transform 'trans' to uv + + Parameters: + ---------- + @trans: 3x3 np.array + transform matrix + @uv: Kx2 np.array + each row is a pair of coordinates (x, y) + + Returns: + ---------- + @xy: Kx2 np.array + each row is a pair of inverse-transformed coordinates (x, y) + """ + Tinv = inv(trans) + xy = 
tformfwd(Tinv, uv) + return xy + + +def findNonreflectiveSimilarity(uv, xy, options=None): + options = {"K": 2} + + K = options["K"] + M = xy.shape[0] + x = xy[:, 0].reshape((-1, 1)) # use reshape to keep a column vector + y = xy[:, 1].reshape((-1, 1)) # use reshape to keep a column vector + + tmp1 = np.hstack((x, y, np.ones((M, 1)), np.zeros((M, 1)))) + tmp2 = np.hstack((y, -x, np.zeros((M, 1)), np.ones((M, 1)))) + X = np.vstack((tmp1, tmp2)) + + u = uv[:, 0].reshape((-1, 1)) # use reshape to keep a column vector + v = uv[:, 1].reshape((-1, 1)) # use reshape to keep a column vector + U = np.vstack((u, v)) + + # We know that X * r = U + if rank(X) >= 2 * K: + r, _, _, _ = lstsq(X, U, rcond=-1) + r = np.squeeze(r) + else: + raise Exception("cp2tform:twoUniquePointsReq") + sc = r[0] + ss = r[1] + tx = r[2] + ty = r[3] + + Tinv = np.array([[sc, -ss, 0], [ss, sc, 0], [tx, ty, 1]]) + T = inv(Tinv) + T[:, 2] = np.array([0, 0, 1]) + + return T, Tinv + + +def findSimilarity(uv, xy, options=None): + options = {"K": 2} + + # uv = np.array(uv) + # xy = np.array(xy) + + # Solve for trans1 + trans1, trans1_inv = findNonreflectiveSimilarity(uv, xy, options) + + # Solve for trans2 + + # manually reflect the xy data across the Y-axis + xyR = xy + xyR[:, 0] = -1 * xyR[:, 0] + + trans2r, trans2r_inv = findNonreflectiveSimilarity(uv, xyR, options) + + # manually reflect the tform to undo the reflection done on xyR + TreflectY = np.array([[-1, 0, 0], [0, 1, 0], [0, 0, 1]]) + + trans2 = np.dot(trans2r, TreflectY) + + # Figure out if trans1 or trans2 is better + xy1 = tformfwd(trans1, uv) + norm1 = norm(xy1 - xy) + + xy2 = tformfwd(trans2, uv) + norm2 = norm(xy2 - xy) + + if norm1 <= norm2: + return trans1, trans1_inv + else: + trans2_inv = inv(trans2) + return trans2, trans2_inv + + +def get_similarity_transform(src_pts, dst_pts, reflective=True): + """ + Function: + ---------- + Find Similarity Transform Matrix 'trans': + u = src_pts[:, 0] + v = src_pts[:, 1] + x = dst_pts[:, 0] + y = dst_pts[:, 1] + [x, y, 1] = [u, v, 1] * trans + + Parameters: + ---------- + @src_pts: Kx2 np.array + source points, each row is a pair of coordinates (x, y) + @dst_pts: Kx2 np.array + destination points, each row is a pair of transformed + coordinates (x, y) + @reflective: True or False + if True: + use reflective similarity transform + else: + use non-reflective similarity transform + + Returns: + ---------- + @trans: 3x3 np.array + transform matrix from uv to xy + trans_inv: 3x3 np.array + inverse of trans, transform matrix from xy to uv + """ + + if reflective: + trans, trans_inv = findSimilarity(src_pts, dst_pts) + else: + trans, trans_inv = findNonreflectiveSimilarity(src_pts, dst_pts) + + return trans, trans_inv + + +def cvt_tform_mat_for_cv2(trans): + """ + Function: + ---------- + Convert Transform Matrix 'trans' into 'cv2_trans' which could be + directly used by cv2.warpAffine(): + u = src_pts[:, 0] + v = src_pts[:, 1] + x = dst_pts[:, 0] + y = dst_pts[:, 1] + [x, y].T = cv_trans * [u, v, 1].T + + Parameters: + ---------- + @trans: 3x3 np.array + transform matrix from uv to xy + + Returns: + ---------- + @cv2_trans: 2x3 np.array + transform matrix from src_pts to dst_pts, could be directly used + for cv2.warpAffine() + """ + cv2_trans = trans[:, 0:2].T + + return cv2_trans + + +def get_similarity_transform_for_cv2(src_pts, dst_pts, reflective=True): + """ + Function: + ---------- + Find Similarity Transform Matrix 'cv2_trans' which could be + directly used by cv2.warpAffine(): + u = src_pts[:, 0] + v = src_pts[:, 1] + 
x = dst_pts[:, 0] + y = dst_pts[:, 1] + [x, y].T = cv_trans * [u, v, 1].T + + Parameters: + ---------- + @src_pts: Kx2 np.array + source points, each row is a pair of coordinates (x, y) + @dst_pts: Kx2 np.array + destination points, each row is a pair of transformed + coordinates (x, y) + reflective: True or False + if True: + use reflective similarity transform + else: + use non-reflective similarity transform + + Returns: + ---------- + @cv2_trans: 2x3 np.array + transform matrix from src_pts to dst_pts, could be directly used + for cv2.warpAffine() + """ + trans, trans_inv = get_similarity_transform(src_pts, dst_pts, reflective) + cv2_trans = cvt_tform_mat_for_cv2(trans) + + return cv2_trans + + +if __name__ == "__main__": + """ + u = [0, 6, -2] + v = [0, 3, 5] + x = [-1, 0, 4] + y = [-1, -10, 4] + + # In Matlab, run: + # + # uv = [u'; v']; + # xy = [x'; y']; + # tform_sim=cp2tform(uv,xy,'similarity'); + # + # trans = tform_sim.tdata.T + # ans = + # -0.0764 -1.6190 0 + # 1.6190 -0.0764 0 + # -3.2156 0.0290 1.0000 + # trans_inv = tform_sim.tdata.Tinv + # ans = + # + # -0.0291 0.6163 0 + # -0.6163 -0.0291 0 + # -0.0756 1.9826 1.0000 + # xy_m=tformfwd(tform_sim, u,v) + # + # xy_m = + # + # -3.2156 0.0290 + # 1.1833 -9.9143 + # 5.0323 2.8853 + # uv_m=tforminv(tform_sim, x,y) + # + # uv_m = + # + # 0.5698 1.3953 + # 6.0872 2.2733 + # -2.6570 4.3314 + """ + u = [0, 6, -2] + v = [0, 3, 5] + x = [-1, 0, 4] + y = [-1, -10, 4] + + uv = np.array((u, v)).T + xy = np.array((x, y)).T + + print("\n--->uv:") + print(uv) + print("\n--->xy:") + print(xy) + + trans, trans_inv = get_similarity_transform(uv, xy) + + print("\n--->trans matrix:") + print(trans) + + print("\n--->trans_inv matrix:") + print(trans_inv) + + print("\n---> apply transform to uv") + print("\nxy_m = uv_augmented * trans") + uv_aug = np.hstack((uv, np.ones((uv.shape[0], 1)))) + xy_m = np.dot(uv_aug, trans) + print(xy_m) + + print("\nxy_m = tformfwd(trans, uv)") + xy_m = tformfwd(trans, uv) + print(xy_m) + + print("\n---> apply inverse transform to xy") + print("\nuv_m = xy_augmented * trans_inv") + xy_aug = np.hstack((xy, np.ones((xy.shape[0], 1)))) + uv_m = np.dot(xy_aug, trans_inv) + print(uv_m) + + print("\nuv_m = tformfwd(trans_inv, xy)") + uv_m = tformfwd(trans_inv, xy) + print(uv_m) + + uv_m = tforminv(trans, xy) + print("\nuv_m = tforminv(trans, xy)") + print(uv_m) diff --git a/hordelib/nodes/facerestore/facelib/detection/retinaface/retinaface.py b/hordelib/nodes/facerestore_cf/facelib/detection/retinaface/retinaface.py similarity index 83% rename from hordelib/nodes/facerestore/facelib/detection/retinaface/retinaface.py rename to hordelib/nodes/facerestore_cf/facelib/detection/retinaface/retinaface.py index 16881365..bfe398b4 100644 --- a/hordelib/nodes/facerestore/facelib/detection/retinaface/retinaface.py +++ b/hordelib/nodes/facerestore_cf/facelib/detection/retinaface/retinaface.py @@ -1,420 +1,389 @@ -import cv2 -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F -from PIL import Image -from torchvision.models._utils import IntermediateLayerGetter as IntermediateLayerGetter - -from hordelib.nodes.facerestore.facelib.detection.align_trans import ( - get_reference_facial_points, - warp_and_crop_face, -) -from hordelib.nodes.facerestore.facelib.detection.retinaface.retinaface_net import ( - FPN, - SSH, - MobileNetV1, - make_bbox_head, - make_class_head, - make_landmark_head, -) -from hordelib.nodes.facerestore.facelib.detection.retinaface.retinaface_utils import ( - PriorBox, - 
batched_decode, - batched_decode_landm, - decode, - decode_landm, - py_cpu_nms, -) - -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - -def generate_config(network_name): - cfg_mnet = { - "name": "mobilenet0.25", - "min_sizes": [[16, 32], [64, 128], [256, 512]], - "steps": [8, 16, 32], - "variance": [0.1, 0.2], - "clip": False, - "loc_weight": 2.0, - "gpu_train": True, - "batch_size": 32, - "ngpu": 1, - "epoch": 250, - "decay1": 190, - "decay2": 220, - "image_size": 640, - "return_layers": {"stage1": 1, "stage2": 2, "stage3": 3}, - "in_channel": 32, - "out_channel": 64, - } - - cfg_re50 = { - "name": "Resnet50", - "min_sizes": [[16, 32], [64, 128], [256, 512]], - "steps": [8, 16, 32], - "variance": [0.1, 0.2], - "clip": False, - "loc_weight": 2.0, - "gpu_train": True, - "batch_size": 24, - "ngpu": 4, - "epoch": 100, - "decay1": 70, - "decay2": 90, - "image_size": 840, - "return_layers": {"layer2": 1, "layer3": 2, "layer4": 3}, - "in_channel": 256, - "out_channel": 256, - } - - if network_name == "mobile0.25": - return cfg_mnet - elif network_name == "resnet50": - return cfg_re50 - else: - raise NotImplementedError(f"network_name={network_name}") - - -class RetinaFace(nn.Module): - def __init__(self, network_name="resnet50", half=False, phase="test"): - super(RetinaFace, self).__init__() - self.half_inference = half - cfg = generate_config(network_name) - self.backbone = cfg["name"] - - self.model_name = f"retinaface_{network_name}" - self.cfg = cfg - self.phase = phase - self.target_size, self.max_size = 1600, 2150 - self.resize, self.scale, self.scale1 = 1.0, None, None - self.mean_tensor = torch.tensor([[[[104.0]], [[117.0]], [[123.0]]]]).to(device) - self.reference = get_reference_facial_points(default_square=True) - # Build network. 
- backbone = None - if cfg["name"] == "mobilenet0.25": - backbone = MobileNetV1() - self.body = IntermediateLayerGetter(backbone, cfg["return_layers"]) - elif cfg["name"] == "Resnet50": - import torchvision.models as models - - backbone = models.resnet50(pretrained=False) - self.body = IntermediateLayerGetter(backbone, cfg["return_layers"]) - - in_channels_stage2 = cfg["in_channel"] - in_channels_list = [ - in_channels_stage2 * 2, - in_channels_stage2 * 4, - in_channels_stage2 * 8, - ] - - out_channels = cfg["out_channel"] - self.fpn = FPN(in_channels_list, out_channels) - self.ssh1 = SSH(out_channels, out_channels) - self.ssh2 = SSH(out_channels, out_channels) - self.ssh3 = SSH(out_channels, out_channels) - - self.ClassHead = make_class_head(fpn_num=3, inchannels=cfg["out_channel"]) - self.BboxHead = make_bbox_head(fpn_num=3, inchannels=cfg["out_channel"]) - self.LandmarkHead = make_landmark_head(fpn_num=3, inchannels=cfg["out_channel"]) - - self.to(device) - self.eval() - if self.half_inference: - self.half() - - def forward(self, inputs): - out = self.body(inputs) - - if self.backbone == "mobilenet0.25" or self.backbone == "Resnet50": - out = list(out.values()) - # FPN - fpn = self.fpn(out) - - # SSH - feature1 = self.ssh1(fpn[0]) - feature2 = self.ssh2(fpn[1]) - feature3 = self.ssh3(fpn[2]) - features = [feature1, feature2, feature3] - - bbox_regressions = torch.cat( - [self.BboxHead[i](feature) for i, feature in enumerate(features)], dim=1 - ) - classifications = torch.cat( - [self.ClassHead[i](feature) for i, feature in enumerate(features)], dim=1 - ) - tmp = [self.LandmarkHead[i](feature) for i, feature in enumerate(features)] - ldm_regressions = torch.cat(tmp, dim=1) - - if self.phase == "train": - output = (bbox_regressions, classifications, ldm_regressions) - else: - output = ( - bbox_regressions, - F.softmax(classifications, dim=-1), - ldm_regressions, - ) - return output - - def __detect_faces(self, inputs): - # get scale - height, width = inputs.shape[2:] - self.scale = torch.tensor( - [width, height, width, height], dtype=torch.float32 - ).to(device) - tmp = [ - width, - height, - width, - height, - width, - height, - width, - height, - width, - height, - ] - self.scale1 = torch.tensor(tmp, dtype=torch.float32).to(device) - - # forawrd - inputs = inputs.to(device) - if self.half_inference: - inputs = inputs.half() - loc, conf, landmarks = self(inputs) - - # get priorbox - priorbox = PriorBox(self.cfg, image_size=inputs.shape[2:]) - priors = priorbox.forward().to(device) - - return loc, conf, landmarks, priors - - # single image detection - def transform(self, image, use_origin_size): - # convert to opencv format - if isinstance(image, Image.Image): - image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR) - image = image.astype(np.float32) - - # testing scale - im_size_min = np.min(image.shape[0:2]) - im_size_max = np.max(image.shape[0:2]) - resize = float(self.target_size) / float(im_size_min) - - # prevent bigger axis from being more than max_size - if np.round(resize * im_size_max) > self.max_size: - resize = float(self.max_size) / float(im_size_max) - resize = 1 if use_origin_size else resize - - # resize - if resize != 1: - image = cv2.resize( - image, None, None, fx=resize, fy=resize, interpolation=cv2.INTER_LINEAR - ) - - # convert to torch.tensor format - # image -= (104, 117, 123) - image = image.transpose(2, 0, 1) - image = torch.from_numpy(image).unsqueeze(0) - - return image, resize - - def detect_faces( - self, - image, - conf_threshold=0.8, - 
nms_threshold=0.4, - use_origin_size=True, - ): - """ - Params: - imgs: BGR image - """ - image, self.resize = self.transform(image, use_origin_size) - image = image.to(device) - if self.half_inference: - image = image.half() - image = image - self.mean_tensor - - loc, conf, landmarks, priors = self.__detect_faces(image) - - boxes = decode(loc.data.squeeze(0), priors.data, self.cfg["variance"]) - boxes = boxes * self.scale / self.resize - boxes = boxes.cpu().numpy() - - scores = conf.squeeze(0).data.cpu().numpy()[:, 1] - - landmarks = decode_landm(landmarks.squeeze(0), priors, self.cfg["variance"]) - landmarks = landmarks * self.scale1 / self.resize - landmarks = landmarks.cpu().numpy() - - # ignore low scores - inds = np.where(scores > conf_threshold)[0] - boxes, landmarks, scores = boxes[inds], landmarks[inds], scores[inds] - - # sort - order = scores.argsort()[::-1] - boxes, landmarks, scores = boxes[order], landmarks[order], scores[order] - - # do NMS - bounding_boxes = np.hstack((boxes, scores[:, np.newaxis])).astype( - np.float32, copy=False - ) - keep = py_cpu_nms(bounding_boxes, nms_threshold) - bounding_boxes, landmarks = bounding_boxes[keep, :], landmarks[keep] - # self.t['forward_pass'].toc() - # print(self.t['forward_pass'].average_time) - # import sys - # sys.stdout.flush() - return np.concatenate((bounding_boxes, landmarks), axis=1) - - def __align_multi(self, image, boxes, landmarks, limit=None): - if len(boxes) < 1: - return [], [] - - if limit: - boxes = boxes[:limit] - landmarks = landmarks[:limit] - - faces = [] - for landmark in landmarks: - facial5points = [[landmark[2 * j], landmark[2 * j + 1]] for j in range(5)] - - warped_face = warp_and_crop_face( - np.array(image), facial5points, self.reference, crop_size=(112, 112) - ) - faces.append(warped_face) - - return np.concatenate((boxes, landmarks), axis=1), faces - - def align_multi(self, img, conf_threshold=0.8, limit=None): - rlt = self.detect_faces(img, conf_threshold=conf_threshold) - boxes, landmarks = rlt[:, 0:5], rlt[:, 5:] - - return self.__align_multi(img, boxes, landmarks, limit) - - # batched detection - def batched_transform(self, frames, use_origin_size): - """ - Arguments: - frames: a list of PIL.Image, or torch.Tensor(shape=[n, h, w, c], - type=np.float32, BGR format). - use_origin_size: whether to use origin size. 
- """ - from_PIL = True if isinstance(frames[0], Image.Image) else False - - # convert to opencv format - if from_PIL: - frames = [ - cv2.cvtColor(np.asarray(frame), cv2.COLOR_RGB2BGR) for frame in frames - ] - frames = np.asarray(frames, dtype=np.float32) - - # testing scale - im_size_min = np.min(frames[0].shape[0:2]) - im_size_max = np.max(frames[0].shape[0:2]) - resize = float(self.target_size) / float(im_size_min) - - # prevent bigger axis from being more than max_size - if np.round(resize * im_size_max) > self.max_size: - resize = float(self.max_size) / float(im_size_max) - resize = 1 if use_origin_size else resize - - # resize - if resize != 1: - if not from_PIL: - frames = F.interpolate(frames, scale_factor=resize) - else: - frames = [ - cv2.resize( - frame, - None, - None, - fx=resize, - fy=resize, - interpolation=cv2.INTER_LINEAR, - ) - for frame in frames - ] - - # convert to torch.tensor format - if not from_PIL: - frames = frames.transpose(1, 2).transpose(1, 3).contiguous() - else: - frames = frames.transpose((0, 3, 1, 2)) - frames = torch.from_numpy(frames) - - return frames, resize - - def batched_detect_faces( - self, frames, conf_threshold=0.8, nms_threshold=0.4, use_origin_size=True - ): - """ - Arguments: - frames: a list of PIL.Image, or np.array(shape=[n, h, w, c], - type=np.uint8, BGR format). - conf_threshold: confidence threshold. - nms_threshold: nms threshold. - use_origin_size: whether to use origin size. - Returns: - final_bounding_boxes: list of np.array ([n_boxes, 5], - type=np.float32). - final_landmarks: list of np.array ([n_boxes, 10], type=np.float32). - """ - # self.t['forward_pass'].tic() - frames, self.resize = self.batched_transform(frames, use_origin_size) - frames = frames.to(device) - frames = frames - self.mean_tensor - - b_loc, b_conf, b_landmarks, priors = self.__detect_faces(frames) - - final_bounding_boxes, final_landmarks = [], [] - - # decode - priors = priors.unsqueeze(0) - b_loc = ( - batched_decode(b_loc, priors, self.cfg["variance"]) - * self.scale - / self.resize - ) - b_landmarks = ( - batched_decode_landm(b_landmarks, priors, self.cfg["variance"]) - * self.scale1 - / self.resize - ) - b_conf = b_conf[:, :, 1] - - # index for selection - b_indice = b_conf > conf_threshold - - # concat - b_loc_and_conf = torch.cat((b_loc, b_conf.unsqueeze(-1)), dim=2).float() - - for pred, landm, inds in zip(b_loc_and_conf, b_landmarks, b_indice): - # ignore low scores - pred, landm = pred[inds, :], landm[inds, :] - if pred.shape[0] == 0: - final_bounding_boxes.append(np.array([], dtype=np.float32)) - final_landmarks.append(np.array([], dtype=np.float32)) - continue - - # sort - # order = score.argsort(descending=True) - # box, landm, score = box[order], landm[order], score[order] - - # to CPU - bounding_boxes, landm = pred.cpu().numpy(), landm.cpu().numpy() - - # NMS - keep = py_cpu_nms(bounding_boxes, nms_threshold) - bounding_boxes, landmarks = bounding_boxes[keep, :], landm[keep] - - # append - final_bounding_boxes.append(bounding_boxes) - final_landmarks.append(landmarks) - # self.t['forward_pass'].toc(average=True) - # self.batch_time += self.t['forward_pass'].diff - # self.total_frame += len(frames) - # print(self.batch_time / self.total_frame) - - return final_bounding_boxes, final_landmarks +import cv2 +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from comfy import model_management +from PIL import Image +from torchvision.models._utils import IntermediateLayerGetter as IntermediateLayerGetter + +from 
hordelib.nodes.facerestore_cf.facelib.detection.align_trans import ( + get_reference_facial_points, + warp_and_crop_face, +) +from hordelib.nodes.facerestore_cf.facelib.detection.retinaface.retinaface_net import ( + FPN, + SSH, + MobileNetV1, + make_bbox_head, + make_class_head, + make_landmark_head, +) +from hordelib.nodes.facerestore_cf.facelib.detection.retinaface.retinaface_utils import ( + PriorBox, + batched_decode, + batched_decode_landm, + decode, + decode_landm, + py_cpu_nms, +) + +# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +device = model_management.get_torch_device() + + +def generate_config(network_name): + + cfg_mnet = { + "name": "mobilenet0.25", + "min_sizes": [[16, 32], [64, 128], [256, 512]], + "steps": [8, 16, 32], + "variance": [0.1, 0.2], + "clip": False, + "loc_weight": 2.0, + "gpu_train": True, + "batch_size": 32, + "ngpu": 1, + "epoch": 250, + "decay1": 190, + "decay2": 220, + "image_size": 640, + "return_layers": { + "stage1": 1, + "stage2": 2, + "stage3": 3, + }, + "in_channel": 32, + "out_channel": 64, + } + + cfg_re50 = { + "name": "Resnet50", + "min_sizes": [[16, 32], [64, 128], [256, 512]], + "steps": [8, 16, 32], + "variance": [0.1, 0.2], + "clip": False, + "loc_weight": 2.0, + "gpu_train": True, + "batch_size": 24, + "ngpu": 4, + "epoch": 100, + "decay1": 70, + "decay2": 90, + "image_size": 840, + "return_layers": { + "layer2": 1, + "layer3": 2, + "layer4": 3, + }, + "in_channel": 256, + "out_channel": 256, + } + + if network_name == "mobile0.25": + return cfg_mnet + elif network_name == "resnet50": + return cfg_re50 + else: + raise NotImplementedError(f"network_name={network_name}") + + +class RetinaFace(nn.Module): + + def __init__(self, network_name="resnet50", half=False, phase="test"): + super(RetinaFace, self).__init__() + self.half_inference = half + cfg = generate_config(network_name) + self.backbone = cfg["name"] + + self.model_name = f"retinaface_{network_name}" + self.cfg = cfg + self.phase = phase + self.target_size, self.max_size = 1600, 2150 + self.resize, self.scale, self.scale1 = 1.0, None, None + self.mean_tensor = torch.tensor([[[[104.0]], [[117.0]], [[123.0]]]]).to(device) + self.reference = get_reference_facial_points(default_square=True) + # Build network. 
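+        # The backbone (MobileNetV1 or torchvision ResNet-50) is wrapped in an
+        # IntermediateLayerGetter that exposes the three stages named in
+        # cfg["return_layers"]; those feature maps feed the FPN, each FPN level
+        # passes through an SSH context module, and the class/bbox/landmark
+        # heads defined below produce the final predictions.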
+ backbone = None + if cfg["name"] == "mobilenet0.25": + backbone = MobileNetV1() + self.body = IntermediateLayerGetter(backbone, cfg["return_layers"]) + elif cfg["name"] == "Resnet50": + import torchvision.models as models + + backbone = models.resnet50(pretrained=False) + self.body = IntermediateLayerGetter(backbone, cfg["return_layers"]) + + in_channels_stage2 = cfg["in_channel"] + in_channels_list = [ + in_channels_stage2 * 2, + in_channels_stage2 * 4, + in_channels_stage2 * 8, + ] + + out_channels = cfg["out_channel"] + self.fpn = FPN(in_channels_list, out_channels) + self.ssh1 = SSH(out_channels, out_channels) + self.ssh2 = SSH(out_channels, out_channels) + self.ssh3 = SSH(out_channels, out_channels) + + self.ClassHead = make_class_head(fpn_num=3, inchannels=cfg["out_channel"]) + self.BboxHead = make_bbox_head(fpn_num=3, inchannels=cfg["out_channel"]) + self.LandmarkHead = make_landmark_head(fpn_num=3, inchannels=cfg["out_channel"]) + + self.to(device) + self.eval() + if self.half_inference: + self.half() + + def forward(self, inputs): + out = self.body(inputs) + + if self.backbone == "mobilenet0.25" or self.backbone == "Resnet50": + out = list(out.values()) + # FPN + fpn = self.fpn(out) + + # SSH + feature1 = self.ssh1(fpn[0]) + feature2 = self.ssh2(fpn[1]) + feature3 = self.ssh3(fpn[2]) + features = [feature1, feature2, feature3] + + bbox_regressions = torch.cat([self.BboxHead[i](feature) for i, feature in enumerate(features)], dim=1) + classifications = torch.cat([self.ClassHead[i](feature) for i, feature in enumerate(features)], dim=1) + tmp = [self.LandmarkHead[i](feature) for i, feature in enumerate(features)] + ldm_regressions = torch.cat(tmp, dim=1) + + if self.phase == "train": + output = (bbox_regressions, classifications, ldm_regressions) + else: + output = (bbox_regressions, F.softmax(classifications, dim=-1), ldm_regressions) + return output + + def __detect_faces(self, inputs): + # get scale + height, width = inputs.shape[2:] + self.scale = torch.tensor([width, height, width, height], dtype=torch.float32).to(device) + tmp = [width, height, width, height, width, height, width, height, width, height] + self.scale1 = torch.tensor(tmp, dtype=torch.float32).to(device) + + # forawrd + inputs = inputs.to(device) + if self.half_inference: + inputs = inputs.half() + loc, conf, landmarks = self(inputs) + + # get priorbox + priorbox = PriorBox(self.cfg, image_size=inputs.shape[2:]) + priors = priorbox.forward().to(device) + + return loc, conf, landmarks, priors + + # single image detection + def transform(self, image, use_origin_size): + # convert to opencv format + if isinstance(image, Image.Image): + image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR) + image = image.astype(np.float32) + + # testing scale + im_size_min = np.min(image.shape[0:2]) + im_size_max = np.max(image.shape[0:2]) + resize = float(self.target_size) / float(im_size_min) + + # prevent bigger axis from being more than max_size + if np.round(resize * im_size_max) > self.max_size: + resize = float(self.max_size) / float(im_size_max) + resize = 1 if use_origin_size else resize + + # resize + if resize != 1: + image = cv2.resize(image, None, None, fx=resize, fy=resize, interpolation=cv2.INTER_LINEAR) + + # convert to torch.tensor format + # image -= (104, 117, 123) + image = image.transpose(2, 0, 1) + image = torch.from_numpy(image).unsqueeze(0) + + return image, resize + + def detect_faces( + self, + image, + conf_threshold=0.8, + nms_threshold=0.4, + use_origin_size=True, + ): + """ + Params: + imgs: BGR 
image + """ + image, self.resize = self.transform(image, use_origin_size) + image = image.to(device) + if self.half_inference: + image = image.half() + image = image - self.mean_tensor + + loc, conf, landmarks, priors = self.__detect_faces(image) + + boxes = decode(loc.data.squeeze(0), priors.data, self.cfg["variance"]) + boxes = boxes * self.scale / self.resize + boxes = boxes.cpu().numpy() + + scores = conf.squeeze(0).data.cpu().numpy()[:, 1] + + landmarks = decode_landm(landmarks.squeeze(0), priors, self.cfg["variance"]) + landmarks = landmarks * self.scale1 / self.resize + landmarks = landmarks.cpu().numpy() + + # ignore low scores + inds = np.where(scores > conf_threshold)[0] + boxes, landmarks, scores = boxes[inds], landmarks[inds], scores[inds] + + # sort + order = scores.argsort()[::-1] + boxes, landmarks, scores = boxes[order], landmarks[order], scores[order] + + # do NMS + bounding_boxes = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False) + keep = py_cpu_nms(bounding_boxes, nms_threshold) + bounding_boxes, landmarks = bounding_boxes[keep, :], landmarks[keep] + # self.t['forward_pass'].toc() + # print(self.t['forward_pass'].average_time) + # import sys + # sys.stdout.flush() + return np.concatenate((bounding_boxes, landmarks), axis=1) + + def __align_multi(self, image, boxes, landmarks, limit=None): + + if len(boxes) < 1: + return [], [] + + if limit: + boxes = boxes[:limit] + landmarks = landmarks[:limit] + + faces = [] + for landmark in landmarks: + facial5points = [[landmark[2 * j], landmark[2 * j + 1]] for j in range(5)] + + warped_face = warp_and_crop_face(np.array(image), facial5points, self.reference, crop_size=(112, 112)) + faces.append(warped_face) + + return np.concatenate((boxes, landmarks), axis=1), faces + + def align_multi(self, img, conf_threshold=0.8, limit=None): + + rlt = self.detect_faces(img, conf_threshold=conf_threshold) + boxes, landmarks = rlt[:, 0:5], rlt[:, 5:] + + return self.__align_multi(img, boxes, landmarks, limit) + + # batched detection + def batched_transform(self, frames, use_origin_size): + """ + Arguments: + frames: a list of PIL.Image, or torch.Tensor(shape=[n, h, w, c], + type=np.float32, BGR format). + use_origin_size: whether to use origin size. 
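+        Example (illustrative sketch; assumes a constructed ``detector = RetinaFace()``
+        and equally sized PIL images):
+            frames, resize = detector.batched_transform([pil_img_1, pil_img_2], use_origin_size=True)
+            # frames: torch.Tensor of shape [n, 3, h, w] in BGR channel order;
+            # resize is 1 because use_origin_size is True.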
+ """ + from_PIL = True if isinstance(frames[0], Image.Image) else False + + # convert to opencv format + if from_PIL: + frames = [cv2.cvtColor(np.asarray(frame), cv2.COLOR_RGB2BGR) for frame in frames] + frames = np.asarray(frames, dtype=np.float32) + + # testing scale + im_size_min = np.min(frames[0].shape[0:2]) + im_size_max = np.max(frames[0].shape[0:2]) + resize = float(self.target_size) / float(im_size_min) + + # prevent bigger axis from being more than max_size + if np.round(resize * im_size_max) > self.max_size: + resize = float(self.max_size) / float(im_size_max) + resize = 1 if use_origin_size else resize + + # resize + if resize != 1: + if not from_PIL: + frames = F.interpolate(frames, scale_factor=resize) + else: + frames = [ + cv2.resize(frame, None, None, fx=resize, fy=resize, interpolation=cv2.INTER_LINEAR) + for frame in frames + ] + + # convert to torch.tensor format + if not from_PIL: + frames = frames.transpose(1, 2).transpose(1, 3).contiguous() + else: + frames = frames.transpose((0, 3, 1, 2)) + frames = torch.from_numpy(frames) + + return frames, resize + + def batched_detect_faces(self, frames, conf_threshold=0.8, nms_threshold=0.4, use_origin_size=True): + """ + Arguments: + frames: a list of PIL.Image, or np.array(shape=[n, h, w, c], + type=np.uint8, BGR format). + conf_threshold: confidence threshold. + nms_threshold: nms threshold. + use_origin_size: whether to use origin size. + Returns: + final_bounding_boxes: list of np.array ([n_boxes, 5], + type=np.float32). + final_landmarks: list of np.array ([n_boxes, 10], type=np.float32). + """ + # self.t['forward_pass'].tic() + frames, self.resize = self.batched_transform(frames, use_origin_size) + frames = frames.to(device) + frames = frames - self.mean_tensor + + b_loc, b_conf, b_landmarks, priors = self.__detect_faces(frames) + + final_bounding_boxes, final_landmarks = [], [] + + # decode + priors = priors.unsqueeze(0) + b_loc = batched_decode(b_loc, priors, self.cfg["variance"]) * self.scale / self.resize + b_landmarks = batched_decode_landm(b_landmarks, priors, self.cfg["variance"]) * self.scale1 / self.resize + b_conf = b_conf[:, :, 1] + + # index for selection + b_indice = b_conf > conf_threshold + + # concat + b_loc_and_conf = torch.cat((b_loc, b_conf.unsqueeze(-1)), dim=2).float() + + for pred, landm, inds in zip(b_loc_and_conf, b_landmarks, b_indice, strict=False): + + # ignore low scores + pred, landm = pred[inds, :], landm[inds, :] + if pred.shape[0] == 0: + final_bounding_boxes.append(np.array([], dtype=np.float32)) + final_landmarks.append(np.array([], dtype=np.float32)) + continue + + # sort + # order = score.argsort(descending=True) + # box, landm, score = box[order], landm[order], score[order] + + # to CPU + bounding_boxes, landm = pred.cpu().numpy(), landm.cpu().numpy() + + # NMS + keep = py_cpu_nms(bounding_boxes, nms_threshold) + bounding_boxes, landmarks = bounding_boxes[keep, :], landm[keep] + + # append + final_bounding_boxes.append(bounding_boxes) + final_landmarks.append(landmarks) + # self.t['forward_pass'].toc(average=True) + # self.batch_time += self.t['forward_pass'].diff + # self.total_frame += len(frames) + # print(self.batch_time / self.total_frame) + + return final_bounding_boxes, final_landmarks diff --git a/hordelib/nodes/facerestore/facelib/detection/retinaface/retinaface_net.py b/hordelib/nodes/facerestore_cf/facelib/detection/retinaface/retinaface_net.py similarity index 90% rename from hordelib/nodes/facerestore/facelib/detection/retinaface/retinaface_net.py rename to 
hordelib/nodes/facerestore_cf/facelib/detection/retinaface/retinaface_net.py index ab6aa82d..bd5f6816 100644 --- a/hordelib/nodes/facerestore/facelib/detection/retinaface/retinaface_net.py +++ b/hordelib/nodes/facerestore_cf/facelib/detection/retinaface/retinaface_net.py @@ -1,196 +1,200 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F - - -def conv_bn(inp, oup, stride=1, leaky=0): - return nn.Sequential( - nn.Conv2d(inp, oup, 3, stride, 1, bias=False), nn.BatchNorm2d(oup), - nn.LeakyReLU(negative_slope=leaky, inplace=True)) - - -def conv_bn_no_relu(inp, oup, stride): - return nn.Sequential( - nn.Conv2d(inp, oup, 3, stride, 1, bias=False), - nn.BatchNorm2d(oup), - ) - - -def conv_bn1X1(inp, oup, stride, leaky=0): - return nn.Sequential( - nn.Conv2d(inp, oup, 1, stride, padding=0, bias=False), nn.BatchNorm2d(oup), - nn.LeakyReLU(negative_slope=leaky, inplace=True)) - - -def conv_dw(inp, oup, stride, leaky=0.1): - return nn.Sequential( - nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False), - nn.BatchNorm2d(inp), - nn.LeakyReLU(negative_slope=leaky, inplace=True), - nn.Conv2d(inp, oup, 1, 1, 0, bias=False), - nn.BatchNorm2d(oup), - nn.LeakyReLU(negative_slope=leaky, inplace=True), - ) - - -class SSH(nn.Module): - - def __init__(self, in_channel, out_channel): - super(SSH, self).__init__() - assert out_channel % 4 == 0 - leaky = 0 - if (out_channel <= 64): - leaky = 0.1 - self.conv3X3 = conv_bn_no_relu(in_channel, out_channel // 2, stride=1) - - self.conv5X5_1 = conv_bn(in_channel, out_channel // 4, stride=1, leaky=leaky) - self.conv5X5_2 = conv_bn_no_relu(out_channel // 4, out_channel // 4, stride=1) - - self.conv7X7_2 = conv_bn(out_channel // 4, out_channel // 4, stride=1, leaky=leaky) - self.conv7x7_3 = conv_bn_no_relu(out_channel // 4, out_channel // 4, stride=1) - - def forward(self, input): - conv3X3 = self.conv3X3(input) - - conv5X5_1 = self.conv5X5_1(input) - conv5X5 = self.conv5X5_2(conv5X5_1) - - conv7X7_2 = self.conv7X7_2(conv5X5_1) - conv7X7 = self.conv7x7_3(conv7X7_2) - - out = torch.cat([conv3X3, conv5X5, conv7X7], dim=1) - out = F.relu(out) - return out - - -class FPN(nn.Module): - - def __init__(self, in_channels_list, out_channels): - super(FPN, self).__init__() - leaky = 0 - if (out_channels <= 64): - leaky = 0.1 - self.output1 = conv_bn1X1(in_channels_list[0], out_channels, stride=1, leaky=leaky) - self.output2 = conv_bn1X1(in_channels_list[1], out_channels, stride=1, leaky=leaky) - self.output3 = conv_bn1X1(in_channels_list[2], out_channels, stride=1, leaky=leaky) - - self.merge1 = conv_bn(out_channels, out_channels, leaky=leaky) - self.merge2 = conv_bn(out_channels, out_channels, leaky=leaky) - - def forward(self, input): - # names = list(input.keys()) - # input = list(input.values()) - - output1 = self.output1(input[0]) - output2 = self.output2(input[1]) - output3 = self.output3(input[2]) - - up3 = F.interpolate(output3, size=[output2.size(2), output2.size(3)], mode='nearest') - output2 = output2 + up3 - output2 = self.merge2(output2) - - up2 = F.interpolate(output2, size=[output1.size(2), output1.size(3)], mode='nearest') - output1 = output1 + up2 - output1 = self.merge1(output1) - - out = [output1, output2, output3] - return out - - -class MobileNetV1(nn.Module): - - def __init__(self): - super(MobileNetV1, self).__init__() - self.stage1 = nn.Sequential( - conv_bn(3, 8, 2, leaky=0.1), # 3 - conv_dw(8, 16, 1), # 7 - conv_dw(16, 32, 2), # 11 - conv_dw(32, 32, 1), # 19 - conv_dw(32, 64, 2), # 27 - conv_dw(64, 64, 1), # 43 - ) - self.stage2 = 
nn.Sequential( - conv_dw(64, 128, 2), # 43 + 16 = 59 - conv_dw(128, 128, 1), # 59 + 32 = 91 - conv_dw(128, 128, 1), # 91 + 32 = 123 - conv_dw(128, 128, 1), # 123 + 32 = 155 - conv_dw(128, 128, 1), # 155 + 32 = 187 - conv_dw(128, 128, 1), # 187 + 32 = 219 - ) - self.stage3 = nn.Sequential( - conv_dw(128, 256, 2), # 219 +3 2 = 241 - conv_dw(256, 256, 1), # 241 + 64 = 301 - ) - self.avg = nn.AdaptiveAvgPool2d((1, 1)) - self.fc = nn.Linear(256, 1000) - - def forward(self, x): - x = self.stage1(x) - x = self.stage2(x) - x = self.stage3(x) - x = self.avg(x) - # x = self.model(x) - x = x.view(-1, 256) - x = self.fc(x) - return x - - -class ClassHead(nn.Module): - - def __init__(self, inchannels=512, num_anchors=3): - super(ClassHead, self).__init__() - self.num_anchors = num_anchors - self.conv1x1 = nn.Conv2d(inchannels, self.num_anchors * 2, kernel_size=(1, 1), stride=1, padding=0) - - def forward(self, x): - out = self.conv1x1(x) - out = out.permute(0, 2, 3, 1).contiguous() - - return out.view(out.shape[0], -1, 2) - - -class BboxHead(nn.Module): - - def __init__(self, inchannels=512, num_anchors=3): - super(BboxHead, self).__init__() - self.conv1x1 = nn.Conv2d(inchannels, num_anchors * 4, kernel_size=(1, 1), stride=1, padding=0) - - def forward(self, x): - out = self.conv1x1(x) - out = out.permute(0, 2, 3, 1).contiguous() - - return out.view(out.shape[0], -1, 4) - - -class LandmarkHead(nn.Module): - - def __init__(self, inchannels=512, num_anchors=3): - super(LandmarkHead, self).__init__() - self.conv1x1 = nn.Conv2d(inchannels, num_anchors * 10, kernel_size=(1, 1), stride=1, padding=0) - - def forward(self, x): - out = self.conv1x1(x) - out = out.permute(0, 2, 3, 1).contiguous() - - return out.view(out.shape[0], -1, 10) - - -def make_class_head(fpn_num=3, inchannels=64, anchor_num=2): - classhead = nn.ModuleList() - for i in range(fpn_num): - classhead.append(ClassHead(inchannels, anchor_num)) - return classhead - - -def make_bbox_head(fpn_num=3, inchannels=64, anchor_num=2): - bboxhead = nn.ModuleList() - for i in range(fpn_num): - bboxhead.append(BboxHead(inchannels, anchor_num)) - return bboxhead - - -def make_landmark_head(fpn_num=3, inchannels=64, anchor_num=2): - landmarkhead = nn.ModuleList() - for i in range(fpn_num): - landmarkhead.append(LandmarkHead(inchannels, anchor_num)) - return landmarkhead +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def conv_bn(inp, oup, stride=1, leaky=0): + return nn.Sequential( + nn.Conv2d(inp, oup, 3, stride, 1, bias=False), + nn.BatchNorm2d(oup), + nn.LeakyReLU(negative_slope=leaky, inplace=True), + ) + + +def conv_bn_no_relu(inp, oup, stride): + return nn.Sequential( + nn.Conv2d(inp, oup, 3, stride, 1, bias=False), + nn.BatchNorm2d(oup), + ) + + +def conv_bn1X1(inp, oup, stride, leaky=0): + return nn.Sequential( + nn.Conv2d(inp, oup, 1, stride, padding=0, bias=False), + nn.BatchNorm2d(oup), + nn.LeakyReLU(negative_slope=leaky, inplace=True), + ) + + +def conv_dw(inp, oup, stride, leaky=0.1): + return nn.Sequential( + nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False), + nn.BatchNorm2d(inp), + nn.LeakyReLU(negative_slope=leaky, inplace=True), + nn.Conv2d(inp, oup, 1, 1, 0, bias=False), + nn.BatchNorm2d(oup), + nn.LeakyReLU(negative_slope=leaky, inplace=True), + ) + + +class SSH(nn.Module): + + def __init__(self, in_channel, out_channel): + super(SSH, self).__init__() + assert out_channel % 4 == 0 + leaky = 0 + if out_channel <= 64: + leaky = 0.1 + self.conv3X3 = conv_bn_no_relu(in_channel, out_channel // 2, stride=1) + 
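+        # conv5X5 is built from two stacked 3x3 convolutions and conv7X7 from
+        # three, so the branches cover growing receptive fields before being
+        # concatenated along the channel dimension in forward().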
+ self.conv5X5_1 = conv_bn(in_channel, out_channel // 4, stride=1, leaky=leaky) + self.conv5X5_2 = conv_bn_no_relu(out_channel // 4, out_channel // 4, stride=1) + + self.conv7X7_2 = conv_bn(out_channel // 4, out_channel // 4, stride=1, leaky=leaky) + self.conv7x7_3 = conv_bn_no_relu(out_channel // 4, out_channel // 4, stride=1) + + def forward(self, input): + conv3X3 = self.conv3X3(input) + + conv5X5_1 = self.conv5X5_1(input) + conv5X5 = self.conv5X5_2(conv5X5_1) + + conv7X7_2 = self.conv7X7_2(conv5X5_1) + conv7X7 = self.conv7x7_3(conv7X7_2) + + out = torch.cat([conv3X3, conv5X5, conv7X7], dim=1) + out = F.relu(out) + return out + + +class FPN(nn.Module): + + def __init__(self, in_channels_list, out_channels): + super(FPN, self).__init__() + leaky = 0 + if out_channels <= 64: + leaky = 0.1 + self.output1 = conv_bn1X1(in_channels_list[0], out_channels, stride=1, leaky=leaky) + self.output2 = conv_bn1X1(in_channels_list[1], out_channels, stride=1, leaky=leaky) + self.output3 = conv_bn1X1(in_channels_list[2], out_channels, stride=1, leaky=leaky) + + self.merge1 = conv_bn(out_channels, out_channels, leaky=leaky) + self.merge2 = conv_bn(out_channels, out_channels, leaky=leaky) + + def forward(self, input): + # names = list(input.keys()) + # input = list(input.values()) + + output1 = self.output1(input[0]) + output2 = self.output2(input[1]) + output3 = self.output3(input[2]) + + up3 = F.interpolate(output3, size=[output2.size(2), output2.size(3)], mode="nearest") + output2 = output2 + up3 + output2 = self.merge2(output2) + + up2 = F.interpolate(output2, size=[output1.size(2), output1.size(3)], mode="nearest") + output1 = output1 + up2 + output1 = self.merge1(output1) + + out = [output1, output2, output3] + return out + + +class MobileNetV1(nn.Module): + + def __init__(self): + super(MobileNetV1, self).__init__() + self.stage1 = nn.Sequential( + conv_bn(3, 8, 2, leaky=0.1), # 3 + conv_dw(8, 16, 1), # 7 + conv_dw(16, 32, 2), # 11 + conv_dw(32, 32, 1), # 19 + conv_dw(32, 64, 2), # 27 + conv_dw(64, 64, 1), # 43 + ) + self.stage2 = nn.Sequential( + conv_dw(64, 128, 2), # 43 + 16 = 59 + conv_dw(128, 128, 1), # 59 + 32 = 91 + conv_dw(128, 128, 1), # 91 + 32 = 123 + conv_dw(128, 128, 1), # 123 + 32 = 155 + conv_dw(128, 128, 1), # 155 + 32 = 187 + conv_dw(128, 128, 1), # 187 + 32 = 219 + ) + self.stage3 = nn.Sequential( + conv_dw(128, 256, 2), # 219 +3 2 = 241 + conv_dw(256, 256, 1), # 241 + 64 = 301 + ) + self.avg = nn.AdaptiveAvgPool2d((1, 1)) + self.fc = nn.Linear(256, 1000) + + def forward(self, x): + x = self.stage1(x) + x = self.stage2(x) + x = self.stage3(x) + x = self.avg(x) + # x = self.model(x) + x = x.view(-1, 256) + x = self.fc(x) + return x + + +class ClassHead(nn.Module): + + def __init__(self, inchannels=512, num_anchors=3): + super(ClassHead, self).__init__() + self.num_anchors = num_anchors + self.conv1x1 = nn.Conv2d(inchannels, self.num_anchors * 2, kernel_size=(1, 1), stride=1, padding=0) + + def forward(self, x): + out = self.conv1x1(x) + out = out.permute(0, 2, 3, 1).contiguous() + + return out.view(out.shape[0], -1, 2) + + +class BboxHead(nn.Module): + + def __init__(self, inchannels=512, num_anchors=3): + super(BboxHead, self).__init__() + self.conv1x1 = nn.Conv2d(inchannels, num_anchors * 4, kernel_size=(1, 1), stride=1, padding=0) + + def forward(self, x): + out = self.conv1x1(x) + out = out.permute(0, 2, 3, 1).contiguous() + + return out.view(out.shape[0], -1, 4) + + +class LandmarkHead(nn.Module): + + def __init__(self, inchannels=512, num_anchors=3): + super(LandmarkHead, 
self).__init__() + self.conv1x1 = nn.Conv2d(inchannels, num_anchors * 10, kernel_size=(1, 1), stride=1, padding=0) + + def forward(self, x): + out = self.conv1x1(x) + out = out.permute(0, 2, 3, 1).contiguous() + + return out.view(out.shape[0], -1, 10) + + +def make_class_head(fpn_num=3, inchannels=64, anchor_num=2): + classhead = nn.ModuleList() + for i in range(fpn_num): + classhead.append(ClassHead(inchannels, anchor_num)) + return classhead + + +def make_bbox_head(fpn_num=3, inchannels=64, anchor_num=2): + bboxhead = nn.ModuleList() + for i in range(fpn_num): + bboxhead.append(BboxHead(inchannels, anchor_num)) + return bboxhead + + +def make_landmark_head(fpn_num=3, inchannels=64, anchor_num=2): + landmarkhead = nn.ModuleList() + for i in range(fpn_num): + landmarkhead.append(LandmarkHead(inchannels, anchor_num)) + return landmarkhead diff --git a/hordelib/nodes/facerestore/facelib/detection/retinaface/retinaface_utils.py b/hordelib/nodes/facerestore_cf/facelib/detection/retinaface/retinaface_utils.py similarity index 92% rename from hordelib/nodes/facerestore/facelib/detection/retinaface/retinaface_utils.py rename to hordelib/nodes/facerestore_cf/facelib/detection/retinaface/retinaface_utils.py index 8c357757..1900c0b7 100644 --- a/hordelib/nodes/facerestore/facelib/detection/retinaface/retinaface_utils.py +++ b/hordelib/nodes/facerestore_cf/facelib/detection/retinaface/retinaface_utils.py @@ -1,421 +1,420 @@ -import numpy as np -import torch -import torchvision -from itertools import product as product -from math import ceil - - -class PriorBox(object): - - def __init__(self, cfg, image_size=None, phase='train'): - super(PriorBox, self).__init__() - self.min_sizes = cfg['min_sizes'] - self.steps = cfg['steps'] - self.clip = cfg['clip'] - self.image_size = image_size - self.feature_maps = [[ceil(self.image_size[0] / step), ceil(self.image_size[1] / step)] for step in self.steps] - self.name = 's' - - def forward(self): - anchors = [] - for k, f in enumerate(self.feature_maps): - min_sizes = self.min_sizes[k] - for i, j in product(range(f[0]), range(f[1])): - for min_size in min_sizes: - s_kx = min_size / self.image_size[1] - s_ky = min_size / self.image_size[0] - dense_cx = [x * self.steps[k] / self.image_size[1] for x in [j + 0.5]] - dense_cy = [y * self.steps[k] / self.image_size[0] for y in [i + 0.5]] - for cy, cx in product(dense_cy, dense_cx): - anchors += [cx, cy, s_kx, s_ky] - - # back to torch land - output = torch.Tensor(anchors).view(-1, 4) - if self.clip: - output.clamp_(max=1, min=0) - return output - - -def py_cpu_nms(dets, thresh): - """Pure Python NMS baseline.""" - keep = torchvision.ops.nms( - boxes=torch.Tensor(dets[:, :4]), - scores=torch.Tensor(dets[:, 4]), - iou_threshold=thresh, - ) - - return list(keep) - - -def point_form(boxes): - """ Convert prior_boxes to (xmin, ymin, xmax, ymax) - representation for comparison to point form ground truth data. - Args: - boxes: (tensor) center-size default boxes from priorbox layers. - Return: - boxes: (tensor) Converted xmin, ymin, xmax, ymax form of boxes. - """ - return torch.cat( - ( - boxes[:, :2] - boxes[:, 2:] / 2, # xmin, ymin - boxes[:, :2] + boxes[:, 2:] / 2), - 1) # xmax, ymax - - -def center_size(boxes): - """ Convert prior_boxes to (cx, cy, w, h) - representation for comparison to center-size form ground truth data. - Args: - boxes: (tensor) point_form boxes - Return: - boxes: (tensor) Converted xmin, ymin, xmax, ymax form of boxes. 
- """ - return torch.cat( - (boxes[:, 2:] + boxes[:, :2]) / 2, # cx, cy - boxes[:, 2:] - boxes[:, :2], - 1) # w, h - - -def intersect(box_a, box_b): - """ We resize both tensors to [A,B,2] without new malloc: - [A,2] -> [A,1,2] -> [A,B,2] - [B,2] -> [1,B,2] -> [A,B,2] - Then we compute the area of intersect between box_a and box_b. - Args: - box_a: (tensor) bounding boxes, Shape: [A,4]. - box_b: (tensor) bounding boxes, Shape: [B,4]. - Return: - (tensor) intersection area, Shape: [A,B]. - """ - A = box_a.size(0) - B = box_b.size(0) - max_xy = torch.min(box_a[:, 2:].unsqueeze(1).expand(A, B, 2), box_b[:, 2:].unsqueeze(0).expand(A, B, 2)) - min_xy = torch.max(box_a[:, :2].unsqueeze(1).expand(A, B, 2), box_b[:, :2].unsqueeze(0).expand(A, B, 2)) - inter = torch.clamp((max_xy - min_xy), min=0) - return inter[:, :, 0] * inter[:, :, 1] - - -def jaccard(box_a, box_b): - """Compute the jaccard overlap of two sets of boxes. The jaccard overlap - is simply the intersection over union of two boxes. Here we operate on - ground truth boxes and default boxes. - E.g.: - A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B) - Args: - box_a: (tensor) Ground truth bounding boxes, Shape: [num_objects,4] - box_b: (tensor) Prior boxes from priorbox layers, Shape: [num_priors,4] - Return: - jaccard overlap: (tensor) Shape: [box_a.size(0), box_b.size(0)] - """ - inter = intersect(box_a, box_b) - area_a = ((box_a[:, 2] - box_a[:, 0]) * (box_a[:, 3] - box_a[:, 1])).unsqueeze(1).expand_as(inter) # [A,B] - area_b = ((box_b[:, 2] - box_b[:, 0]) * (box_b[:, 3] - box_b[:, 1])).unsqueeze(0).expand_as(inter) # [A,B] - union = area_a + area_b - inter - return inter / union # [A,B] - - -def matrix_iou(a, b): - """ - return iou of a and b, numpy version for data augenmentation - """ - lt = np.maximum(a[:, np.newaxis, :2], b[:, :2]) - rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:]) - - area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2) - area_a = np.prod(a[:, 2:] - a[:, :2], axis=1) - area_b = np.prod(b[:, 2:] - b[:, :2], axis=1) - return area_i / (area_a[:, np.newaxis] + area_b - area_i) - - -def matrix_iof(a, b): - """ - return iof of a and b, numpy version for data augenmentation - """ - lt = np.maximum(a[:, np.newaxis, :2], b[:, :2]) - rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:]) - - area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2) - area_a = np.prod(a[:, 2:] - a[:, :2], axis=1) - return area_i / np.maximum(area_a[:, np.newaxis], 1) - - -def match(threshold, truths, priors, variances, labels, landms, loc_t, conf_t, landm_t, idx): - """Match each prior box with the ground truth box of the highest jaccard - overlap, encode the bounding boxes, then return the matched indices - corresponding to both confidence and location preds. - Args: - threshold: (float) The overlap threshold used when matching boxes. - truths: (tensor) Ground truth boxes, Shape: [num_obj, 4]. - priors: (tensor) Prior boxes from priorbox layers, Shape: [n_priors,4]. - variances: (tensor) Variances corresponding to each prior coord, - Shape: [num_priors, 4]. - labels: (tensor) All the class labels for the image, Shape: [num_obj]. - landms: (tensor) Ground truth landms, Shape [num_obj, 10]. - loc_t: (tensor) Tensor to be filled w/ encoded location targets. - conf_t: (tensor) Tensor to be filled w/ matched indices for conf preds. - landm_t: (tensor) Tensor to be filled w/ encoded landm targets. - idx: (int) current batch index - Return: - The matched indices corresponding to 1)location 2)confidence - 3)landm preds. 
- """ - # jaccard index - overlaps = jaccard(truths, point_form(priors)) - # (Bipartite Matching) - # [1,num_objects] best prior for each ground truth - best_prior_overlap, best_prior_idx = overlaps.max(1, keepdim=True) - - # ignore hard gt - valid_gt_idx = best_prior_overlap[:, 0] >= 0.2 - best_prior_idx_filter = best_prior_idx[valid_gt_idx, :] - if best_prior_idx_filter.shape[0] <= 0: - loc_t[idx] = 0 - conf_t[idx] = 0 - return - - # [1,num_priors] best ground truth for each prior - best_truth_overlap, best_truth_idx = overlaps.max(0, keepdim=True) - best_truth_idx.squeeze_(0) - best_truth_overlap.squeeze_(0) - best_prior_idx.squeeze_(1) - best_prior_idx_filter.squeeze_(1) - best_prior_overlap.squeeze_(1) - best_truth_overlap.index_fill_(0, best_prior_idx_filter, 2) # ensure best prior - # TODO refactor: index best_prior_idx with long tensor - # ensure every gt matches with its prior of max overlap - for j in range(best_prior_idx.size(0)): # 判别此anchor是预测哪一个boxes - best_truth_idx[best_prior_idx[j]] = j - matches = truths[best_truth_idx] # Shape: [num_priors,4] 此处为每一个anchor对应的bbox取出来 - conf = labels[best_truth_idx] # Shape: [num_priors] 此处为每一个anchor对应的label取出来 - conf[best_truth_overlap < threshold] = 0 # label as background overlap<0.35的全部作为负样本 - loc = encode(matches, priors, variances) - - matches_landm = landms[best_truth_idx] - landm = encode_landm(matches_landm, priors, variances) - loc_t[idx] = loc # [num_priors,4] encoded offsets to learn - conf_t[idx] = conf # [num_priors] top class label for each prior - landm_t[idx] = landm - - -def encode(matched, priors, variances): - """Encode the variances from the priorbox layers into the ground truth boxes - we have matched (based on jaccard overlap) with the prior boxes. - Args: - matched: (tensor) Coords of ground truth for each prior in point-form - Shape: [num_priors, 4]. - priors: (tensor) Prior boxes in center-offset form - Shape: [num_priors,4]. - variances: (list[float]) Variances of priorboxes - Return: - encoded boxes (tensor), Shape: [num_priors, 4] - """ - - # dist b/t match center and prior's center - g_cxcy = (matched[:, :2] + matched[:, 2:]) / 2 - priors[:, :2] - # encode variance - g_cxcy /= (variances[0] * priors[:, 2:]) - # match wh / prior wh - g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:] - g_wh = torch.log(g_wh) / variances[1] - # return target for smooth_l1_loss - return torch.cat([g_cxcy, g_wh], 1) # [num_priors,4] - - -def encode_landm(matched, priors, variances): - """Encode the variances from the priorbox layers into the ground truth boxes - we have matched (based on jaccard overlap) with the prior boxes. - Args: - matched: (tensor) Coords of ground truth for each prior in point-form - Shape: [num_priors, 10]. - priors: (tensor) Prior boxes in center-offset form - Shape: [num_priors,4]. 
- variances: (list[float]) Variances of priorboxes - Return: - encoded landm (tensor), Shape: [num_priors, 10] - """ - - # dist b/t match center and prior's center - matched = torch.reshape(matched, (matched.size(0), 5, 2)) - priors_cx = priors[:, 0].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2) - priors_cy = priors[:, 1].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2) - priors_w = priors[:, 2].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2) - priors_h = priors[:, 3].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2) - priors = torch.cat([priors_cx, priors_cy, priors_w, priors_h], dim=2) - g_cxcy = matched[:, :, :2] - priors[:, :, :2] - # encode variance - g_cxcy /= (variances[0] * priors[:, :, 2:]) - # g_cxcy /= priors[:, :, 2:] - g_cxcy = g_cxcy.reshape(g_cxcy.size(0), -1) - # return target for smooth_l1_loss - return g_cxcy - - -# Adapted from https://github.com/Hakuyume/chainer-ssd -def decode(loc, priors, variances): - """Decode locations from predictions using priors to undo - the encoding we did for offset regression at train time. - Args: - loc (tensor): location predictions for loc layers, - Shape: [num_priors,4] - priors (tensor): Prior boxes in center-offset form. - Shape: [num_priors,4]. - variances: (list[float]) Variances of priorboxes - Return: - decoded bounding box predictions - """ - - boxes = torch.cat((priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:], - priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1) - boxes[:, :2] -= boxes[:, 2:] / 2 - boxes[:, 2:] += boxes[:, :2] - return boxes - - -def decode_landm(pre, priors, variances): - """Decode landm from predictions using priors to undo - the encoding we did for offset regression at train time. - Args: - pre (tensor): landm predictions for loc layers, - Shape: [num_priors,10] - priors (tensor): Prior boxes in center-offset form. - Shape: [num_priors,4]. - variances: (list[float]) Variances of priorboxes - Return: - decoded landm predictions - """ - tmp = ( - priors[:, :2] + pre[:, :2] * variances[0] * priors[:, 2:], - priors[:, :2] + pre[:, 2:4] * variances[0] * priors[:, 2:], - priors[:, :2] + pre[:, 4:6] * variances[0] * priors[:, 2:], - priors[:, :2] + pre[:, 6:8] * variances[0] * priors[:, 2:], - priors[:, :2] + pre[:, 8:10] * variances[0] * priors[:, 2:], - ) - landms = torch.cat(tmp, dim=1) - return landms - - -def batched_decode(b_loc, priors, variances): - """Decode locations from predictions using priors to undo - the encoding we did for offset regression at train time. - Args: - b_loc (tensor): location predictions for loc layers, - Shape: [num_batches,num_priors,4] - priors (tensor): Prior boxes in center-offset form. - Shape: [1,num_priors,4]. - variances: (list[float]) Variances of priorboxes - Return: - decoded bounding box predictions - """ - boxes = ( - priors[:, :, :2] + b_loc[:, :, :2] * variances[0] * priors[:, :, 2:], - priors[:, :, 2:] * torch.exp(b_loc[:, :, 2:] * variances[1]), - ) - boxes = torch.cat(boxes, dim=2) - - boxes[:, :, :2] -= boxes[:, :, 2:] / 2 - boxes[:, :, 2:] += boxes[:, :, :2] - return boxes - - -def batched_decode_landm(pre, priors, variances): - """Decode landm from predictions using priors to undo - the encoding we did for offset regression at train time. - Args: - pre (tensor): landm predictions for loc layers, - Shape: [num_batches,num_priors,10] - priors (tensor): Prior boxes in center-offset form. - Shape: [1,num_priors,4]. 
- variances: (list[float]) Variances of priorboxes - Return: - decoded landm predictions - """ - landms = ( - priors[:, :, :2] + pre[:, :, :2] * variances[0] * priors[:, :, 2:], - priors[:, :, :2] + pre[:, :, 2:4] * variances[0] * priors[:, :, 2:], - priors[:, :, :2] + pre[:, :, 4:6] * variances[0] * priors[:, :, 2:], - priors[:, :, :2] + pre[:, :, 6:8] * variances[0] * priors[:, :, 2:], - priors[:, :, :2] + pre[:, :, 8:10] * variances[0] * priors[:, :, 2:], - ) - landms = torch.cat(landms, dim=2) - return landms - - -def log_sum_exp(x): - """Utility function for computing log_sum_exp while determining - This will be used to determine unaveraged confidence loss across - all examples in a batch. - Args: - x (Variable(tensor)): conf_preds from conf layers - """ - x_max = x.data.max() - return torch.log(torch.sum(torch.exp(x - x_max), 1, keepdim=True)) + x_max - - -# Original author: Francisco Massa: -# https://github.com/fmassa/object-detection.torch -# Ported to PyTorch by Max deGroot (02/01/2017) -def nms(boxes, scores, overlap=0.5, top_k=200): - """Apply non-maximum suppression at test time to avoid detecting too many - overlapping bounding boxes for a given object. - Args: - boxes: (tensor) The location preds for the img, Shape: [num_priors,4]. - scores: (tensor) The class predscores for the img, Shape:[num_priors]. - overlap: (float) The overlap thresh for suppressing unnecessary boxes. - top_k: (int) The Maximum number of box preds to consider. - Return: - The indices of the kept boxes with respect to num_priors. - """ - - keep = torch.Tensor(scores.size(0)).fill_(0).long() - if boxes.numel() == 0: - return keep - x1 = boxes[:, 0] - y1 = boxes[:, 1] - x2 = boxes[:, 2] - y2 = boxes[:, 3] - area = torch.mul(x2 - x1, y2 - y1) - v, idx = scores.sort(0) # sort in ascending order - # I = I[v >= 0.01] - idx = idx[-top_k:] # indices of the top-k largest vals - xx1 = boxes.new() - yy1 = boxes.new() - xx2 = boxes.new() - yy2 = boxes.new() - w = boxes.new() - h = boxes.new() - - # keep = torch.Tensor() - count = 0 - while idx.numel() > 0: - i = idx[-1] # index of current largest val - # keep.append(i) - keep[count] = i - count += 1 - if idx.size(0) == 1: - break - idx = idx[:-1] # remove kept element from view - # load bboxes of next highest vals - torch.index_select(x1, 0, idx, out=xx1) - torch.index_select(y1, 0, idx, out=yy1) - torch.index_select(x2, 0, idx, out=xx2) - torch.index_select(y2, 0, idx, out=yy2) - # store element-wise max with next highest score - xx1 = torch.clamp(xx1, min=x1[i]) - yy1 = torch.clamp(yy1, min=y1[i]) - xx2 = torch.clamp(xx2, max=x2[i]) - yy2 = torch.clamp(yy2, max=y2[i]) - w.resize_as_(xx2) - h.resize_as_(yy2) - w = xx2 - xx1 - h = yy2 - yy1 - # check sizes of xx1 and xx2.. 
after each iteration - w = torch.clamp(w, min=0.0) - h = torch.clamp(h, min=0.0) - inter = w * h - # IoU = i / (area(a) + area(b) - i) - rem_areas = torch.index_select(area, 0, idx) # load remaining areas) - union = (rem_areas - inter) + area[i] - IoU = inter / union # store result in iou - # keep only elements with an IoU <= overlap - idx = idx[IoU.le(overlap)] - return keep, count +from itertools import product as product +from math import ceil + +import numpy as np +import torch +import torchvision + + +class PriorBox: + + def __init__(self, cfg, image_size=None, phase="train"): + super(PriorBox, self).__init__() + self.min_sizes = cfg["min_sizes"] + self.steps = cfg["steps"] + self.clip = cfg["clip"] + self.image_size = image_size + self.feature_maps = [[ceil(self.image_size[0] / step), ceil(self.image_size[1] / step)] for step in self.steps] + self.name = "s" + + def forward(self): + anchors = [] + for k, f in enumerate(self.feature_maps): + min_sizes = self.min_sizes[k] + for i, j in product(range(f[0]), range(f[1])): + for min_size in min_sizes: + s_kx = min_size / self.image_size[1] + s_ky = min_size / self.image_size[0] + dense_cx = [x * self.steps[k] / self.image_size[1] for x in [j + 0.5]] + dense_cy = [y * self.steps[k] / self.image_size[0] for y in [i + 0.5]] + for cy, cx in product(dense_cy, dense_cx): + anchors += [cx, cy, s_kx, s_ky] + + # back to torch land + output = torch.Tensor(anchors).view(-1, 4) + if self.clip: + output.clamp_(max=1, min=0) + return output + + +def py_cpu_nms(dets, thresh): + """Pure Python NMS baseline.""" + keep = torchvision.ops.nms( + boxes=torch.Tensor(dets[:, :4]), + scores=torch.Tensor(dets[:, 4]), + iou_threshold=thresh, + ) + + return list(keep) + + +def point_form(boxes): + """Convert prior_boxes to (xmin, ymin, xmax, ymax) + representation for comparison to point form ground truth data. + Args: + boxes: (tensor) center-size default boxes from priorbox layers. + Return: + boxes: (tensor) Converted xmin, ymin, xmax, ymax form of boxes. + """ + return torch.cat((boxes[:, :2] - boxes[:, 2:] / 2, boxes[:, :2] + boxes[:, 2:] / 2), 1) # xmin, ymin # xmax, ymax + + +def center_size(boxes): + """Convert prior_boxes to (cx, cy, w, h) + representation for comparison to center-size form ground truth data. + Args: + boxes: (tensor) point_form boxes + Return: + boxes: (tensor) Converted xmin, ymin, xmax, ymax form of boxes. + """ + return torch.cat((boxes[:, 2:] + boxes[:, :2]) / 2, boxes[:, 2:] - boxes[:, :2], 1) # cx, cy # w, h + + +def intersect(box_a, box_b): + """We resize both tensors to [A,B,2] without new malloc: + [A,2] -> [A,1,2] -> [A,B,2] + [B,2] -> [1,B,2] -> [A,B,2] + Then we compute the area of intersect between box_a and box_b. + Args: + box_a: (tensor) bounding boxes, Shape: [A,4]. + box_b: (tensor) bounding boxes, Shape: [B,4]. + Return: + (tensor) intersection area, Shape: [A,B]. + """ + A = box_a.size(0) + B = box_b.size(0) + max_xy = torch.min(box_a[:, 2:].unsqueeze(1).expand(A, B, 2), box_b[:, 2:].unsqueeze(0).expand(A, B, 2)) + min_xy = torch.max(box_a[:, :2].unsqueeze(1).expand(A, B, 2), box_b[:, :2].unsqueeze(0).expand(A, B, 2)) + inter = torch.clamp((max_xy - min_xy), min=0) + return inter[:, :, 0] * inter[:, :, 1] + + +def jaccard(box_a, box_b): + """Compute the jaccard overlap of two sets of boxes. The jaccard overlap + is simply the intersection over union of two boxes. Here we operate on + ground truth boxes and default boxes. 
+ E.g.: + A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B) + Args: + box_a: (tensor) Ground truth bounding boxes, Shape: [num_objects,4] + box_b: (tensor) Prior boxes from priorbox layers, Shape: [num_priors,4] + Return: + jaccard overlap: (tensor) Shape: [box_a.size(0), box_b.size(0)] + """ + inter = intersect(box_a, box_b) + area_a = ((box_a[:, 2] - box_a[:, 0]) * (box_a[:, 3] - box_a[:, 1])).unsqueeze(1).expand_as(inter) # [A,B] + area_b = ((box_b[:, 2] - box_b[:, 0]) * (box_b[:, 3] - box_b[:, 1])).unsqueeze(0).expand_as(inter) # [A,B] + union = area_a + area_b - inter + return inter / union # [A,B] + + +def matrix_iou(a, b): + """ + return iou of a and b, numpy version for data augenmentation + """ + lt = np.maximum(a[:, np.newaxis, :2], b[:, :2]) + rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:]) + + area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2) + area_a = np.prod(a[:, 2:] - a[:, :2], axis=1) + area_b = np.prod(b[:, 2:] - b[:, :2], axis=1) + return area_i / (area_a[:, np.newaxis] + area_b - area_i) + + +def matrix_iof(a, b): + """ + return iof of a and b, numpy version for data augenmentation + """ + lt = np.maximum(a[:, np.newaxis, :2], b[:, :2]) + rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:]) + + area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2) + area_a = np.prod(a[:, 2:] - a[:, :2], axis=1) + return area_i / np.maximum(area_a[:, np.newaxis], 1) + + +def match(threshold, truths, priors, variances, labels, landms, loc_t, conf_t, landm_t, idx): + """Match each prior box with the ground truth box of the highest jaccard + overlap, encode the bounding boxes, then return the matched indices + corresponding to both confidence and location preds. + Args: + threshold: (float) The overlap threshold used when matching boxes. + truths: (tensor) Ground truth boxes, Shape: [num_obj, 4]. + priors: (tensor) Prior boxes from priorbox layers, Shape: [n_priors,4]. + variances: (tensor) Variances corresponding to each prior coord, + Shape: [num_priors, 4]. + labels: (tensor) All the class labels for the image, Shape: [num_obj]. + landms: (tensor) Ground truth landms, Shape [num_obj, 10]. + loc_t: (tensor) Tensor to be filled w/ encoded location targets. + conf_t: (tensor) Tensor to be filled w/ matched indices for conf preds. + landm_t: (tensor) Tensor to be filled w/ encoded landm targets. + idx: (int) current batch index + Return: + The matched indices corresponding to 1)location 2)confidence + 3)landm preds. 
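+    Note:
+        The encoded targets are written in place into loc_t, conf_t and
+        landm_t at batch index idx; the function itself returns None.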
+ """ + # jaccard index + overlaps = jaccard(truths, point_form(priors)) + # (Bipartite Matching) + # [1,num_objects] best prior for each ground truth + best_prior_overlap, best_prior_idx = overlaps.max(1, keepdim=True) + + # ignore hard gt + valid_gt_idx = best_prior_overlap[:, 0] >= 0.2 + best_prior_idx_filter = best_prior_idx[valid_gt_idx, :] + if best_prior_idx_filter.shape[0] <= 0: + loc_t[idx] = 0 + conf_t[idx] = 0 + return + + # [1,num_priors] best ground truth for each prior + best_truth_overlap, best_truth_idx = overlaps.max(0, keepdim=True) + best_truth_idx.squeeze_(0) + best_truth_overlap.squeeze_(0) + best_prior_idx.squeeze_(1) + best_prior_idx_filter.squeeze_(1) + best_prior_overlap.squeeze_(1) + best_truth_overlap.index_fill_(0, best_prior_idx_filter, 2) # ensure best prior + # TODO refactor: index best_prior_idx with long tensor + # ensure every gt matches with its prior of max overlap + for j in range(best_prior_idx.size(0)): # 判别此anchor是预测哪一个boxes + best_truth_idx[best_prior_idx[j]] = j + matches = truths[best_truth_idx] # Shape: [num_priors,4] 此处为每一个anchor对应的bbox取出来 + conf = labels[best_truth_idx] # Shape: [num_priors] 此处为每一个anchor对应的label取出来 + conf[best_truth_overlap < threshold] = 0 # label as background overlap<0.35的全部作为负样本 + loc = encode(matches, priors, variances) + + matches_landm = landms[best_truth_idx] + landm = encode_landm(matches_landm, priors, variances) + loc_t[idx] = loc # [num_priors,4] encoded offsets to learn + conf_t[idx] = conf # [num_priors] top class label for each prior + landm_t[idx] = landm + + +def encode(matched, priors, variances): + """Encode the variances from the priorbox layers into the ground truth boxes + we have matched (based on jaccard overlap) with the prior boxes. + Args: + matched: (tensor) Coords of ground truth for each prior in point-form + Shape: [num_priors, 4]. + priors: (tensor) Prior boxes in center-offset form + Shape: [num_priors,4]. + variances: (list[float]) Variances of priorboxes + Return: + encoded boxes (tensor), Shape: [num_priors, 4] + """ + + # dist b/t match center and prior's center + g_cxcy = (matched[:, :2] + matched[:, 2:]) / 2 - priors[:, :2] + # encode variance + g_cxcy /= variances[0] * priors[:, 2:] + # match wh / prior wh + g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:] + g_wh = torch.log(g_wh) / variances[1] + # return target for smooth_l1_loss + return torch.cat([g_cxcy, g_wh], 1) # [num_priors,4] + + +def encode_landm(matched, priors, variances): + """Encode the variances from the priorbox layers into the ground truth boxes + we have matched (based on jaccard overlap) with the prior boxes. + Args: + matched: (tensor) Coords of ground truth for each prior in point-form + Shape: [num_priors, 10]. + priors: (tensor) Prior boxes in center-offset form + Shape: [num_priors,4]. 
+ variances: (list[float]) Variances of priorboxes + Return: + encoded landm (tensor), Shape: [num_priors, 10] + """ + + # dist b/t match center and prior's center + matched = torch.reshape(matched, (matched.size(0), 5, 2)) + priors_cx = priors[:, 0].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2) + priors_cy = priors[:, 1].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2) + priors_w = priors[:, 2].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2) + priors_h = priors[:, 3].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2) + priors = torch.cat([priors_cx, priors_cy, priors_w, priors_h], dim=2) + g_cxcy = matched[:, :, :2] - priors[:, :, :2] + # encode variance + g_cxcy /= variances[0] * priors[:, :, 2:] + # g_cxcy /= priors[:, :, 2:] + g_cxcy = g_cxcy.reshape(g_cxcy.size(0), -1) + # return target for smooth_l1_loss + return g_cxcy + + +# Adapted from https://github.com/Hakuyume/chainer-ssd +def decode(loc, priors, variances): + """Decode locations from predictions using priors to undo + the encoding we did for offset regression at train time. + Args: + loc (tensor): location predictions for loc layers, + Shape: [num_priors,4] + priors (tensor): Prior boxes in center-offset form. + Shape: [num_priors,4]. + variances: (list[float]) Variances of priorboxes + Return: + decoded bounding box predictions + """ + + boxes = torch.cat( + ( + priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:], + priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1]), + ), + 1, + ) + boxes[:, :2] -= boxes[:, 2:] / 2 + boxes[:, 2:] += boxes[:, :2] + return boxes + + +def decode_landm(pre, priors, variances): + """Decode landm from predictions using priors to undo + the encoding we did for offset regression at train time. + Args: + pre (tensor): landm predictions for loc layers, + Shape: [num_priors,10] + priors (tensor): Prior boxes in center-offset form. + Shape: [num_priors,4]. + variances: (list[float]) Variances of priorboxes + Return: + decoded landm predictions + """ + tmp = ( + priors[:, :2] + pre[:, :2] * variances[0] * priors[:, 2:], + priors[:, :2] + pre[:, 2:4] * variances[0] * priors[:, 2:], + priors[:, :2] + pre[:, 4:6] * variances[0] * priors[:, 2:], + priors[:, :2] + pre[:, 6:8] * variances[0] * priors[:, 2:], + priors[:, :2] + pre[:, 8:10] * variances[0] * priors[:, 2:], + ) + landms = torch.cat(tmp, dim=1) + return landms + + +def batched_decode(b_loc, priors, variances): + """Decode locations from predictions using priors to undo + the encoding we did for offset regression at train time. + Args: + b_loc (tensor): location predictions for loc layers, + Shape: [num_batches,num_priors,4] + priors (tensor): Prior boxes in center-offset form. + Shape: [1,num_priors,4]. + variances: (list[float]) Variances of priorboxes + Return: + decoded bounding box predictions + """ + boxes = ( + priors[:, :, :2] + b_loc[:, :, :2] * variances[0] * priors[:, :, 2:], + priors[:, :, 2:] * torch.exp(b_loc[:, :, 2:] * variances[1]), + ) + boxes = torch.cat(boxes, dim=2) + + boxes[:, :, :2] -= boxes[:, :, 2:] / 2 + boxes[:, :, 2:] += boxes[:, :, :2] + return boxes + + +def batched_decode_landm(pre, priors, variances): + """Decode landm from predictions using priors to undo + the encoding we did for offset regression at train time. + Args: + pre (tensor): landm predictions for loc layers, + Shape: [num_batches,num_priors,10] + priors (tensor): Prior boxes in center-offset form. + Shape: [1,num_priors,4]. 
+ variances: (list[float]) Variances of priorboxes + Return: + decoded landm predictions + """ + landms = ( + priors[:, :, :2] + pre[:, :, :2] * variances[0] * priors[:, :, 2:], + priors[:, :, :2] + pre[:, :, 2:4] * variances[0] * priors[:, :, 2:], + priors[:, :, :2] + pre[:, :, 4:6] * variances[0] * priors[:, :, 2:], + priors[:, :, :2] + pre[:, :, 6:8] * variances[0] * priors[:, :, 2:], + priors[:, :, :2] + pre[:, :, 8:10] * variances[0] * priors[:, :, 2:], + ) + landms = torch.cat(landms, dim=2) + return landms + + +def log_sum_exp(x): + """Utility function for computing log_sum_exp while determining + This will be used to determine unaveraged confidence loss across + all examples in a batch. + Args: + x (Variable(tensor)): conf_preds from conf layers + """ + x_max = x.data.max() + return torch.log(torch.sum(torch.exp(x - x_max), 1, keepdim=True)) + x_max + + +# Original author: Francisco Massa: +# https://github.com/fmassa/object-detection.torch +# Ported to PyTorch by Max deGroot (02/01/2017) +def nms(boxes, scores, overlap=0.5, top_k=200): + """Apply non-maximum suppression at test time to avoid detecting too many + overlapping bounding boxes for a given object. + Args: + boxes: (tensor) The location preds for the img, Shape: [num_priors,4]. + scores: (tensor) The class predscores for the img, Shape:[num_priors]. + overlap: (float) The overlap thresh for suppressing unnecessary boxes. + top_k: (int) The Maximum number of box preds to consider. + Return: + The indices of the kept boxes with respect to num_priors. + """ + + keep = torch.Tensor(scores.size(0)).fill_(0).long() + if boxes.numel() == 0: + return keep + x1 = boxes[:, 0] + y1 = boxes[:, 1] + x2 = boxes[:, 2] + y2 = boxes[:, 3] + area = torch.mul(x2 - x1, y2 - y1) + v, idx = scores.sort(0) # sort in ascending order + # I = I[v >= 0.01] + idx = idx[-top_k:] # indices of the top-k largest vals + xx1 = boxes.new() + yy1 = boxes.new() + xx2 = boxes.new() + yy2 = boxes.new() + w = boxes.new() + h = boxes.new() + + # keep = torch.Tensor() + count = 0 + while idx.numel() > 0: + i = idx[-1] # index of current largest val + # keep.append(i) + keep[count] = i + count += 1 + if idx.size(0) == 1: + break + idx = idx[:-1] # remove kept element from view + # load bboxes of next highest vals + torch.index_select(x1, 0, idx, out=xx1) + torch.index_select(y1, 0, idx, out=yy1) + torch.index_select(x2, 0, idx, out=xx2) + torch.index_select(y2, 0, idx, out=yy2) + # store element-wise max with next highest score + xx1 = torch.clamp(xx1, min=x1[i]) + yy1 = torch.clamp(yy1, min=y1[i]) + xx2 = torch.clamp(xx2, max=x2[i]) + yy2 = torch.clamp(yy2, max=y2[i]) + w.resize_as_(xx2) + h.resize_as_(yy2) + w = xx2 - xx1 + h = yy2 - yy1 + # check sizes of xx1 and xx2.. 
after each iteration + w = torch.clamp(w, min=0.0) + h = torch.clamp(h, min=0.0) + inter = w * h + # IoU = i / (area(a) + area(b) - i) + rem_areas = torch.index_select(area, 0, idx) # load remaining areas) + union = (rem_areas - inter) + area[i] + IoU = inter / union # store result in iou + # keep only elements with an IoU <= overlap + idx = idx[IoU.le(overlap)] + return keep, count diff --git a/hordelib/nodes/facerestore/facelib/detection/yolov5face/models/__init__.py b/hordelib/nodes/facerestore_cf/facelib/detection/yolov5face/__init__.py similarity index 100% rename from hordelib/nodes/facerestore/facelib/detection/yolov5face/models/__init__.py rename to hordelib/nodes/facerestore_cf/facelib/detection/yolov5face/__init__.py diff --git a/hordelib/nodes/facerestore/facelib/detection/yolov5face/face_detector.py b/hordelib/nodes/facerestore_cf/facelib/detection/yolov5face/face_detector.py similarity index 71% rename from hordelib/nodes/facerestore/facelib/detection/yolov5face/face_detector.py rename to hordelib/nodes/facerestore_cf/facelib/detection/yolov5face/face_detector.py index 5ea44d65..7dd6378c 100644 --- a/hordelib/nodes/facerestore/facelib/detection/yolov5face/face_detector.py +++ b/hordelib/nodes/facerestore_cf/facelib/detection/yolov5face/face_detector.py @@ -1,175 +1,148 @@ -import copy -import os -from pathlib import Path - -import cv2 -import numpy as np -import torch -from torch import nn - -from hordelib.nodes.facerestore.facelib.detection.yolov5face.models.common import Conv -from hordelib.nodes.facerestore.facelib.detection.yolov5face.models.yolo import Model -from hordelib.nodes.facerestore.facelib.detection.yolov5face.utils.datasets import letterbox -from hordelib.nodes.facerestore.facelib.detection.yolov5face.utils.general import ( - check_img_size, - non_max_suppression_face, - scale_coords, - scale_coords_landmarks, -) - - -def is_high_version(): - from packaging import version - try: - torch_v = version.parse(torch.__version__) - return torch_v > version.parse("1.9.0") - except Exception: - return True - - -def isListempty(inList): - if isinstance(inList, list): # Is a list - return all(map(isListempty, inList)) - return False # Not a list - - -class YoloDetector: - def __init__( - self, - config_name, - min_face=10, - target_size=None, - device="cuda", - ): - """ - config_name: name of .yaml config with network configuration from models/ folder. - min_face : minimal face size in pixels. - target_size : target size of smaller image axis (choose lower for faster work). e.g. 480, 720, 1080. - None for original resolution. - """ - self._class_path = Path(__file__).parent.absolute() - self.target_size = target_size - self.min_face = min_face - self.detector = Model(cfg=config_name) - self.device = device - - def _preprocess(self, imgs): - """ - Preprocessing image before passing through the network. Resize and conversion to torch tensor. 
- """ - pp_imgs = [] - for img in imgs: - h0, w0 = img.shape[:2] # orig hw - if self.target_size: - r = self.target_size / min(h0, w0) # resize image to img_size - if r < 1: - img = cv2.resize( - img, (int(w0 * r), int(h0 * r)), interpolation=cv2.INTER_LINEAR - ) - - imgsz = check_img_size( - max(img.shape[:2]), s=self.detector.stride.max() - ) # check img_size - img = letterbox(img, new_shape=imgsz)[0] - pp_imgs.append(img) - pp_imgs = np.array(pp_imgs) - pp_imgs = pp_imgs.transpose(0, 3, 1, 2) - pp_imgs = torch.from_numpy(pp_imgs).to(self.device) - pp_imgs = pp_imgs.float() # uint8 to fp16/32 - return pp_imgs / 255.0 # 0 - 255 to 0.0 - 1.0 - - def _postprocess(self, imgs, origimgs, pred, conf_thres, iou_thres): - """ - Postprocessing of raw pytorch model output. - Returns: - bboxes: list of arrays with 4 coordinates of bounding boxes with format x1,y1,x2,y2. - points: list of arrays with coordinates of 5 facial keypoints (eyes, nose, lips corners). - """ - bboxes = [[] for _ in range(len(origimgs))] - landmarks = [[] for _ in range(len(origimgs))] - - pred = non_max_suppression_face(pred, conf_thres, iou_thres) - - for image_id, origimg in enumerate(origimgs): - img_shape = origimg.shape - image_height, image_width = img_shape[:2] - gn = torch.tensor(img_shape)[[1, 0, 1, 0]] # normalization gain whwh - gn_lks = torch.tensor(img_shape)[ - [1, 0, 1, 0, 1, 0, 1, 0, 1, 0] - ] # normalization gain landmarks - det = pred[image_id].cpu() - scale_coords(imgs[image_id].shape[1:], det[:, :4], img_shape).round() - scale_coords_landmarks( - imgs[image_id].shape[1:], det[:, 5:15], img_shape - ).round() - - for j in range(det.size()[0]): - box = (det[j, :4].view(1, 4) / gn).view(-1).tolist() - box = list( - map( - int, - [ - box[0] * image_width, - box[1] * image_height, - box[2] * image_width, - box[3] * image_height, - ], - ) - ) - if box[3] - box[1] < self.min_face: - continue - lm = (det[j, 5:15].view(1, 10) / gn_lks).view(-1).tolist() - lm = list( - map( - int, - [ - i * image_width if j % 2 == 0 else i * image_height - for j, i in enumerate(lm) - ], - ) - ) - lm = [lm[i : i + 2] for i in range(0, len(lm), 2)] - bboxes[image_id].append(box) - landmarks[image_id].append(lm) - return bboxes, landmarks - - def detect_faces(self, imgs, conf_thres=0.7, iou_thres=0.5): - """ - Get bbox coordinates and keypoints of faces on original image. - Params: - imgs: image or list of images to detect faces on with BGR order (convert to RGB order for inference) - conf_thres: confidence threshold for each prediction - iou_thres: threshold for NMS (filter of intersecting bboxes) - Returns: - bboxes: list of arrays with 4 coordinates of bounding boxes with format x1,y1,x2,y2. - points: list of arrays with coordinates of 5 facial keypoints (eyes, nose, lips corners). 
- """ - # Pass input images through face detector - images = imgs if isinstance(imgs, list) else [imgs] - images = [cv2.cvtColor(img, cv2.COLOR_BGR2RGB) for img in images] - origimgs = copy.deepcopy(images) - - images = self._preprocess(images) - - if is_high_version(): - with torch.inference_mode(): # for pytorch>=1.9 - pred = self.detector(images)[0] - else: - with torch.no_grad(): # for pytorch<1.9 - pred = self.detector(images)[0] - - bboxes, points = self._postprocess( - images, origimgs, pred, conf_thres, iou_thres - ) - - # return bboxes, points - if not isListempty(points): - bboxes = np.array(bboxes).reshape(-1, 4) - points = np.array(points).reshape(-1, 10) - padding = bboxes[:, 0].reshape(-1, 1) - return np.concatenate((bboxes, padding, points), axis=1) - else: - return None - - def __call__(self, *args): - return self.predict(*args) +import copy +from pathlib import Path + +import cv2 +import numpy as np +import torch + +from hordelib.nodes.facerestore_cf.facelib.detection.yolov5face.models.yolo import Model +from hordelib.nodes.facerestore_cf.facelib.detection.yolov5face.utils.datasets import letterbox +from hordelib.nodes.facerestore_cf.facelib.detection.yolov5face.utils.general import ( + check_img_size, + non_max_suppression_face, + scale_coords, + scale_coords_landmarks, +) + +try: + version_str = torch.__version__.split("+")[0] + major, minor, patch = map(int, version_str.split(".")) + IS_HIGH_VERSION = (major, minor, patch) >= (1, 9, 0) +except ValueError: + # Handle the case of a development version here + IS_HIGH_VERSION = False + + +def isListempty(inList): + if isinstance(inList, list): # Is a list + return all(map(isListempty, inList)) + return False # Not a list + + +class YoloDetector: + def __init__( + self, + config_name, + min_face=10, + target_size=None, + device="cuda", + ): + """ + config_name: name of .yaml config with network configuration from models/ folder. + min_face : minimal face size in pixels. + target_size : target size of smaller image axis (choose lower for faster work). e.g. 480, 720, 1080. + None for original resolution. + """ + self._class_path = Path(__file__).parent.absolute() + self.target_size = target_size + self.min_face = min_face + self.detector = Model(cfg=config_name) + self.device = device + + def _preprocess(self, imgs): + """ + Preprocessing image before passing through the network. Resize and conversion to torch tensor. + """ + pp_imgs = [] + for img in imgs: + h0, w0 = img.shape[:2] # orig hw + if self.target_size: + r = self.target_size / min(h0, w0) # resize image to img_size + if r < 1: + img = cv2.resize(img, (int(w0 * r), int(h0 * r)), interpolation=cv2.INTER_LINEAR) + + imgsz = check_img_size(max(img.shape[:2]), s=self.detector.stride.max()) # check img_size + img = letterbox(img, new_shape=imgsz)[0] + pp_imgs.append(img) + pp_imgs = np.array(pp_imgs) + pp_imgs = pp_imgs.transpose(0, 3, 1, 2) + pp_imgs = torch.from_numpy(pp_imgs).to(self.device) + pp_imgs = pp_imgs.float() # uint8 to fp16/32 + return pp_imgs / 255.0 # 0 - 255 to 0.0 - 1.0 + + def _postprocess(self, imgs, origimgs, pred, conf_thres, iou_thres): + """ + Postprocessing of raw pytorch model output. + Returns: + bboxes: list of arrays with 4 coordinates of bounding boxes with format x1,y1,x2,y2. + points: list of arrays with coordinates of 5 facial keypoints (eyes, nose, lips corners). 
+ """ + bboxes = [[] for _ in range(len(origimgs))] + landmarks = [[] for _ in range(len(origimgs))] + + pred = non_max_suppression_face(pred, conf_thres, iou_thres) + + for image_id, origimg in enumerate(origimgs): + img_shape = origimg.shape + image_height, image_width = img_shape[:2] + gn = torch.tensor(img_shape)[[1, 0, 1, 0]] # normalization gain whwh + gn_lks = torch.tensor(img_shape)[[1, 0, 1, 0, 1, 0, 1, 0, 1, 0]] # normalization gain landmarks + det = pred[image_id].cpu() + scale_coords(imgs[image_id].shape[1:], det[:, :4], img_shape).round() + scale_coords_landmarks(imgs[image_id].shape[1:], det[:, 5:15], img_shape).round() + + for j in range(det.size()[0]): + box = (det[j, :4].view(1, 4) / gn).view(-1).tolist() + box = list( + map( + int, + [box[0] * image_width, box[1] * image_height, box[2] * image_width, box[3] * image_height], + ), + ) + if box[3] - box[1] < self.min_face: + continue + lm = (det[j, 5:15].view(1, 10) / gn_lks).view(-1).tolist() + lm = list(map(int, [i * image_width if j % 2 == 0 else i * image_height for j, i in enumerate(lm)])) + lm = [lm[i : i + 2] for i in range(0, len(lm), 2)] + bboxes[image_id].append(box) + landmarks[image_id].append(lm) + return bboxes, landmarks + + def detect_faces(self, imgs, conf_thres=0.7, iou_thres=0.5): + """ + Get bbox coordinates and keypoints of faces on original image. + Params: + imgs: image or list of images to detect faces on with BGR order (convert to RGB order for inference) + conf_thres: confidence threshold for each prediction + iou_thres: threshold for NMS (filter of intersecting bboxes) + Returns: + bboxes: list of arrays with 4 coordinates of bounding boxes with format x1,y1,x2,y2. + points: list of arrays with coordinates of 5 facial keypoints (eyes, nose, lips corners). + """ + # Pass input images through face detector + images = imgs if isinstance(imgs, list) else [imgs] + images = [cv2.cvtColor(img, cv2.COLOR_BGR2RGB) for img in images] + origimgs = copy.deepcopy(images) + + images = self._preprocess(images) + + if IS_HIGH_VERSION: + with torch.inference_mode(): # for pytorch>=1.9 + pred = self.detector(images)[0] + else: + with torch.no_grad(): # for pytorch<1.9 + pred = self.detector(images)[0] + + bboxes, points = self._postprocess(images, origimgs, pred, conf_thres, iou_thres) + + # return bboxes, points + if not isListempty(points): + bboxes = np.array(bboxes).reshape(-1, 4) + points = np.array(points).reshape(-1, 10) + padding = bboxes[:, 0].reshape(-1, 1) + return np.concatenate((bboxes, padding, points), axis=1) + else: + return None + + def __call__(self, *args): + return self.predict(*args) diff --git a/hordelib/nodes/facerestore/facelib/detection/yolov5face/utils/__init__.py b/hordelib/nodes/facerestore_cf/facelib/detection/yolov5face/models/__init__.py similarity index 100% rename from hordelib/nodes/facerestore/facelib/detection/yolov5face/utils/__init__.py rename to hordelib/nodes/facerestore_cf/facelib/detection/yolov5face/models/__init__.py diff --git a/hordelib/nodes/facerestore/facelib/detection/yolov5face/models/common.py b/hordelib/nodes/facerestore_cf/facelib/detection/yolov5face/models/common.py similarity index 92% rename from hordelib/nodes/facerestore/facelib/detection/yolov5face/models/common.py rename to hordelib/nodes/facerestore_cf/facelib/detection/yolov5face/models/common.py index 9c6c5eaa..d9559779 100644 --- a/hordelib/nodes/facerestore/facelib/detection/yolov5face/models/common.py +++ b/hordelib/nodes/facerestore_cf/facelib/detection/yolov5face/models/common.py @@ -6,8 +6,8 @@ 
import torch from torch import nn -from hordelib.nodes.facerestore.facelib.detection.yolov5face.utils.datasets import letterbox -from hordelib.nodes.facerestore.facelib.detection.yolov5face.utils.general import ( +from hordelib.nodes.facerestore_cf.facelib.detection.yolov5face.utils.datasets import letterbox +from hordelib.nodes.facerestore_cf.facelib.detection.yolov5face.utils.general import ( make_divisible, non_max_suppression, scale_coords, @@ -149,22 +149,9 @@ def __init__(self, inp, oup, stride): ), nn.BatchNorm2d(branch_features), nn.SiLU(), - self.depthwise_conv( - branch_features, - branch_features, - kernel_size=3, - stride=self.stride, - padding=1, - ), + self.depthwise_conv(branch_features, branch_features, kernel_size=3, stride=self.stride, padding=1), nn.BatchNorm2d(branch_features), - nn.Conv2d( - branch_features, - branch_features, - kernel_size=1, - stride=1, - padding=0, - bias=False, - ), + nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False), nn.BatchNorm2d(branch_features), nn.SiLU(), ) @@ -204,17 +191,7 @@ def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, k self.conv = Conv(c1 * 4, c2, k, s, p, g, act) def forward(self, x): # x(b,c,w,h) -> y(b,4c,w/2,h/2) - return self.conv( - torch.cat( - [ - x[..., ::2, ::2], - x[..., 1::2, ::2], - x[..., ::2, 1::2], - x[..., 1::2, 1::2], - ], - 1, - ) - ) + return self.conv(torch.cat([x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]], 1)) class Concat(nn.Module): diff --git a/hordelib/nodes/facerestore/facelib/detection/yolov5face/models/experimental.py b/hordelib/nodes/facerestore_cf/facelib/detection/yolov5face/models/experimental.py similarity index 78% rename from hordelib/nodes/facerestore/facelib/detection/yolov5face/models/experimental.py rename to hordelib/nodes/facerestore_cf/facelib/detection/yolov5face/models/experimental.py index c33acba5..e04914e9 100644 --- a/hordelib/nodes/facerestore/facelib/detection/yolov5face/models/experimental.py +++ b/hordelib/nodes/facerestore_cf/facelib/detection/yolov5face/models/experimental.py @@ -4,7 +4,7 @@ import torch from torch import nn -from hordelib.nodes.facerestore.facelib.detection.yolov5face.models.common import Conv +from hordelib.nodes.facerestore_cf.facelib.detection.yolov5face.models.common import Conv class CrossConv(nn.Module): @@ -35,16 +35,9 @@ def __init__(self, c1, c2, k=(1, 3), s=1, equal_ch=True): a -= np.roll(a, 1, axis=1) a *= np.array(k) ** 2 a[0] = 1 - c_ = np.linalg.lstsq(a, b, rcond=None)[ - 0 - ].round() # solve for equal weight indices, ax = b - - self.m = nn.ModuleList( - [ - nn.Conv2d(c1, int(c_[g]), k[g], s, k[g] // 2, bias=False) - for g in range(groups) - ] - ) + c_ = np.linalg.lstsq(a, b, rcond=None)[0].round() # solve for equal weight indices, ax = b + + self.m = nn.ModuleList([nn.Conv2d(c1, int(c_[g]), k[g], s, k[g] // 2, bias=False) for g in range(groups)]) self.bn = nn.BatchNorm2d(c2) self.act = nn.LeakyReLU(0.1, inplace=True) diff --git a/hordelib/nodes/facerestore/facelib/detection/yolov5face/models/yolo.py b/hordelib/nodes/facerestore_cf/facelib/detection/yolov5face/models/yolo.py similarity index 64% rename from hordelib/nodes/facerestore/facelib/detection/yolov5face/models/yolo.py rename to hordelib/nodes/facerestore_cf/facelib/detection/yolov5face/models/yolo.py index b11af3e4..fe22178d 100644 --- a/hordelib/nodes/facerestore/facelib/detection/yolov5face/models/yolo.py +++ b/hordelib/nodes/facerestore_cf/facelib/detection/yolov5face/models/yolo.py @@ 
-6,7 +6,7 @@ import yaml # for torch hub from torch import nn -from hordelib.nodes.facerestore.facelib.detection.yolov5face.models.common import ( +from hordelib.nodes.facerestore_cf.facelib.detection.yolov5face.models.common import ( C3, NMS, SPP, @@ -20,16 +20,10 @@ ShuffleV2Block, StemBlock, ) -from hordelib.nodes.facerestore.facelib.detection.yolov5face.models.experimental import ( - CrossConv, - MixConv2d, -) -from hordelib.nodes.facerestore.facelib.detection.yolov5face.utils.autoanchor import check_anchor_order -from hordelib.nodes.facerestore.facelib.detection.yolov5face.utils.general import make_divisible -from hordelib.nodes.facerestore.facelib.detection.yolov5face.utils.torch_utils import ( - copy_attr, - fuse_conv_and_bn, -) +from hordelib.nodes.facerestore_cf.facelib.detection.yolov5face.models.experimental import CrossConv, MixConv2d +from hordelib.nodes.facerestore_cf.facelib.detection.yolov5face.utils.autoanchor import check_anchor_order +from hordelib.nodes.facerestore_cf.facelib.detection.yolov5face.utils.general import make_divisible +from hordelib.nodes.facerestore_cf.facelib.detection.yolov5face.utils.torch_utils import copy_attr, fuse_conv_and_bn class Detect(nn.Module): @@ -46,12 +40,8 @@ def __init__(self, nc=80, anchors=(), ch=()): # detection layer self.grid = [torch.zeros(1)] * self.nl # init grid a = torch.tensor(anchors).float().view(self.nl, -1, 2) self.register_buffer("anchors", a) # shape(nl,na,2) - self.register_buffer( - "anchor_grid", a.clone().view(self.nl, 1, -1, 1, 1, 2) - ) # shape(nl,1,na,1,1,2) - self.m = nn.ModuleList( - nn.Conv2d(x, self.no * self.na, 1) for x in ch - ) # output conv + self.register_buffer("anchor_grid", a.clone().view(self.nl, 1, -1, 1, 1, 2)) # shape(nl,1,na,1,1,2) + self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch) # output conv def forward(self, x): z = [] # inference output @@ -62,12 +52,7 @@ def forward(self, x): for i in range(self.nl): x[i] = self.m[i](x[i]) # conv bs, _, ny, nx = x[i].shape # x(bs,255,20,20) to x(bs,3,20,20,85) - x[i] = ( - x[i] - .view(bs, self.na, self.no, ny, nx) - .permute(0, 1, 3, 4, 2) - .contiguous() - ) + x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous() if not self.training: # inference if self.grid[i].shape[2:4] != x[i].shape[2:4]: @@ -77,32 +62,23 @@ def forward(self, x): y[..., [0, 1, 2, 3, 4, 15]] = x[i][..., [0, 1, 2, 3, 4, 15]].sigmoid() y[..., 5:15] = x[i][..., 5:15] - y[..., 0:2] = ( - y[..., 0:2] * 2.0 - 0.5 + self.grid[i].to(x[i].device) - ) * self.stride[ - i - ] # xy + y[..., 0:2] = (y[..., 0:2] * 2.0 - 0.5 + self.grid[i].to(x[i].device)) * self.stride[i] # xy y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh y[..., 5:7] = ( - y[..., 5:7] * self.anchor_grid[i] - + self.grid[i].to(x[i].device) * self.stride[i] + y[..., 5:7] * self.anchor_grid[i] + self.grid[i].to(x[i].device) * self.stride[i] ) # landmark x1 y1 y[..., 7:9] = ( - y[..., 7:9] * self.anchor_grid[i] - + self.grid[i].to(x[i].device) * self.stride[i] + y[..., 7:9] * self.anchor_grid[i] + self.grid[i].to(x[i].device) * self.stride[i] ) # landmark x2 y2 y[..., 9:11] = ( - y[..., 9:11] * self.anchor_grid[i] - + self.grid[i].to(x[i].device) * self.stride[i] + y[..., 9:11] * self.anchor_grid[i] + self.grid[i].to(x[i].device) * self.stride[i] ) # landmark x3 y3 y[..., 11:13] = ( - y[..., 11:13] * self.anchor_grid[i] - + self.grid[i].to(x[i].device) * self.stride[i] + y[..., 11:13] * self.anchor_grid[i] + self.grid[i].to(x[i].device) * self.stride[i] ) # landmark x4 
y4 y[..., 13:15] = ( - y[..., 13:15] * self.anchor_grid[i] - + self.grid[i].to(x[i].device) * self.stride[i] + y[..., 13:15] * self.anchor_grid[i] + self.grid[i].to(x[i].device) * self.stride[i] ) # landmark x5 y5 z.append(y.view(bs, -1, self.no)) @@ -117,9 +93,7 @@ def _make_grid(nx=20, ny=20): class Model(nn.Module): - def __init__( - self, cfg="yolov5s.yaml", ch=3, nc=None - ): # model, input channels, number of classes + def __init__(self, cfg="yolov5s.yaml", ch=3, nc=None): # model, input channels, number of classes super().__init__() self.yaml_file = Path(cfg).name with Path(cfg).open(encoding="utf8") as f: @@ -130,18 +104,14 @@ def __init__( if nc and nc != self.yaml["nc"]: self.yaml["nc"] = nc # override yaml value - self.model, self.save = parse_model( - deepcopy(self.yaml), ch=[ch] - ) # model, savelist + self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch]) # model, savelist self.names = [str(i) for i in range(self.yaml["nc"])] # default names # Build strides, anchors m = self.model[-1] # Detect() if isinstance(m, Detect): s = 128 # 2x min stride - m.stride = torch.tensor( - [s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))] - ) # forward + m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))]) # forward m.anchors /= m.stride.view(-1, 1, 1) check_anchor_order(m) self.stride = m.stride @@ -154,42 +124,27 @@ def forward_once(self, x): y = [] # outputs for m in self.model: if m.f != -1: # if not from previous layer - x = ( - y[m.f] - if isinstance(m.f, int) - else [x if j == -1 else y[j] for j in m.f] - ) # from earlier layers + x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers x = m(x) # run y.append(x if m.i in self.save else None) # save output return x - def _initialize_biases( - self, cf=None - ): # initialize biases into Detect(), cf is class frequency + def _initialize_biases(self, cf=None): # initialize biases into Detect(), cf is class frequency # https://arxiv.org/abs/1708.02002 section 3.3 m = self.model[-1] # Detect() module for mi, s in zip(m.m, m.stride): # from b = mi.bias.view(m.na, -1) # conv.bias(255) to (3,85) - b.data[:, 4] += math.log( - 8 / (640 / s) ** 2 - ) # obj (8 objects per 640 image) - b.data[:, 5:] += ( - math.log(0.6 / (m.nc - 0.99)) - if cf is None - else torch.log(cf / cf.sum()) - ) # cls + b.data[:, 4] += math.log(8 / (640 / s) ** 2) # obj (8 objects per 640 image) + b.data[:, 5:] += math.log(0.6 / (m.nc - 0.99)) if cf is None else torch.log(cf / cf.sum()) # cls mi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True) def _print_biases(self): m = self.model[-1] # Detect() module for mi in m.m: # from b = mi.bias.detach().view(m.na, -1).T # conv.bias(255) to (3,85) - print( - ("%6g Conv2d.bias:" + "%10.3g" * 6) - % (mi.weight.shape[1], *b[:5].mean(1).tolist(), b[5:].mean()) - ) + print(("%6g Conv2d.bias:" + "%10.3g" * 6) % (mi.weight.shape[1], *b[:5].mean(1).tolist(), b[5:].mean())) def fuse(self): # fuse model Conv2d() + BatchNorm2d() layers print("Fusing layers... ") @@ -219,28 +174,17 @@ def nms(self, mode=True): # add or remove NMS module def autoshape(self): # add autoShape module print("Adding autoShape... 
") m = AutoShape(self) # wrap model - copy_attr( - m, self, include=("yaml", "nc", "hyp", "names", "stride"), exclude=() - ) # copy attributes + copy_attr(m, self, include=("yaml", "nc", "hyp", "names", "stride"), exclude=()) # copy attributes return m def parse_model(d, ch): # model_dict, input_channels(3) - anchors, nc, gd, gw = ( - d["anchors"], - d["nc"], - d["depth_multiple"], - d["width_multiple"], - ) - na = ( - (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors - ) # number of anchors + anchors, nc, gd, gw = d["anchors"], d["nc"], d["depth_multiple"], d["width_multiple"] + na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors # number of anchors no = na * (nc + 5) # number of outputs = anchors * (classes + 5) layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out - for i, (f, n, m, args) in enumerate( - d["backbone"] + d["head"] - ): # from, number, module, args + for i, (f, n, m, args) in enumerate(d["backbone"] + d["head"]): # from, number, module, args m = eval(m) if isinstance(m, str) else m # eval strings for j, a in enumerate(args): try: @@ -281,20 +225,11 @@ def parse_model(d, ch): # model_dict, input_channels(3) else: c2 = ch[f] - m_ = ( - nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args) - ) # module + m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args) # module t = str(m)[8:-2].replace("__main__.", "") # module type np = sum(x.numel() for x in m_.parameters()) # number params - m_.i, m_.f, m_.type, m_.np = ( - i, - f, - t, - np, - ) # attach index, 'from' index, type, number params - save.extend( - x % i for x in ([f] if isinstance(f, int) else f) if x != -1 - ) # append to savelist + m_.i, m_.f, m_.type, m_.np = i, f, t, np # attach index, 'from' index, type, number params + save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist layers.append(m_) ch.append(c2) return nn.Sequential(*layers), sorted(save) diff --git a/hordelib/nodes/facerestore/facelib/detection/yolov5face/models/yolov5l.yaml b/hordelib/nodes/facerestore_cf/facelib/detection/yolov5face/models/yolov5l.yaml similarity index 100% rename from hordelib/nodes/facerestore/facelib/detection/yolov5face/models/yolov5l.yaml rename to hordelib/nodes/facerestore_cf/facelib/detection/yolov5face/models/yolov5l.yaml diff --git a/hordelib/nodes/facerestore/facelib/detection/yolov5face/models/yolov5n.yaml b/hordelib/nodes/facerestore_cf/facelib/detection/yolov5face/models/yolov5n.yaml similarity index 100% rename from hordelib/nodes/facerestore/facelib/detection/yolov5face/models/yolov5n.yaml rename to hordelib/nodes/facerestore_cf/facelib/detection/yolov5face/models/yolov5n.yaml diff --git a/hordelib/nodes/facerestore_cf/facelib/detection/yolov5face/utils/__init__.py b/hordelib/nodes/facerestore_cf/facelib/detection/yolov5face/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/hordelib/nodes/facerestore/facelib/detection/yolov5face/utils/autoanchor.py b/hordelib/nodes/facerestore_cf/facelib/detection/yolov5face/utils/autoanchor.py similarity index 97% rename from hordelib/nodes/facerestore/facelib/detection/yolov5face/utils/autoanchor.py rename to hordelib/nodes/facerestore_cf/facelib/detection/yolov5face/utils/autoanchor.py index a4eba3e9..cb0de894 100644 --- a/hordelib/nodes/facerestore/facelib/detection/yolov5face/utils/autoanchor.py +++ b/hordelib/nodes/facerestore_cf/facelib/detection/yolov5face/utils/autoanchor.py @@ -1,12 +1,12 @@ -# Auto-anchor utils - - -def 
check_anchor_order(m): - # Check anchor order against stride order for YOLOv5 Detect() module m, and correct if necessary - a = m.anchor_grid.prod(-1).view(-1) # anchor area - da = a[-1] - a[0] # delta a - ds = m.stride[-1] - m.stride[0] # delta s - if da.sign() != ds.sign(): # same order - print("Reversing anchor order") - m.anchors[:] = m.anchors.flip(0) - m.anchor_grid[:] = m.anchor_grid.flip(0) +# Auto-anchor utils + + +def check_anchor_order(m): + # Check anchor order against stride order for YOLOv5 Detect() module m, and correct if necessary + a = m.anchor_grid.prod(-1).view(-1) # anchor area + da = a[-1] - a[0] # delta a + ds = m.stride[-1] - m.stride[0] # delta s + if da.sign() != ds.sign(): # same order + print("Reversing anchor order") + m.anchors[:] = m.anchors.flip(0) + m.anchor_grid[:] = m.anchor_grid.flip(0) diff --git a/hordelib/nodes/facerestore/facelib/detection/yolov5face/utils/datasets.py b/hordelib/nodes/facerestore_cf/facelib/detection/yolov5face/utils/datasets.py similarity index 97% rename from hordelib/nodes/facerestore/facelib/detection/yolov5face/utils/datasets.py rename to hordelib/nodes/facerestore_cf/facelib/detection/yolov5face/utils/datasets.py index e672b136..a72609b4 100644 --- a/hordelib/nodes/facerestore/facelib/detection/yolov5face/utils/datasets.py +++ b/hordelib/nodes/facerestore_cf/facelib/detection/yolov5face/utils/datasets.py @@ -1,35 +1,35 @@ -import cv2 -import numpy as np - - -def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scale_fill=False, scaleup=True): - # Resize image to a 32-pixel-multiple rectangle https://github.com/ultralytics/yolov3/issues/232 - shape = img.shape[:2] # current shape [height, width] - if isinstance(new_shape, int): - new_shape = (new_shape, new_shape) - - # Scale ratio (new / old) - r = min(new_shape[0] / shape[0], new_shape[1] / shape[1]) - if not scaleup: # only scale down, do not scale up (for better test mAP) - r = min(r, 1.0) - - # Compute padding - ratio = r, r # width, height ratios - new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r)) - dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding - if auto: # minimum rectangle - dw, dh = np.mod(dw, 64), np.mod(dh, 64) # wh padding - elif scale_fill: # stretch - dw, dh = 0.0, 0.0 - new_unpad = (new_shape[1], new_shape[0]) - ratio = new_shape[1] / shape[1], new_shape[0] / shape[0] # width, height ratios - - dw /= 2 # divide padding into 2 sides - dh /= 2 - - if shape[::-1] != new_unpad: # resize - img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR) - top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1)) - left, right = int(round(dw - 0.1)), int(round(dw + 0.1)) - img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) # add border - return img, ratio, (dw, dh) +import cv2 +import numpy as np + + +def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scale_fill=False, scaleup=True): + # Resize image to a 32-pixel-multiple rectangle https://github.com/ultralytics/yolov3/issues/232 + shape = img.shape[:2] # current shape [height, width] + if isinstance(new_shape, int): + new_shape = (new_shape, new_shape) + + # Scale ratio (new / old) + r = min(new_shape[0] / shape[0], new_shape[1] / shape[1]) + if not scaleup: # only scale down, do not scale up (for better test mAP) + r = min(r, 1.0) + + # Compute padding + ratio = r, r # width, height ratios + new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r)) + dw, dh = new_shape[1] - 
new_unpad[0], new_shape[0] - new_unpad[1] # wh padding + if auto: # minimum rectangle + dw, dh = np.mod(dw, 64), np.mod(dh, 64) # wh padding + elif scale_fill: # stretch + dw, dh = 0.0, 0.0 + new_unpad = (new_shape[1], new_shape[0]) + ratio = new_shape[1] / shape[1], new_shape[0] / shape[0] # width, height ratios + + dw /= 2 # divide padding into 2 sides + dh /= 2 + + if shape[::-1] != new_unpad: # resize + img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR) + top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1)) + left, right = int(round(dw - 0.1)), int(round(dw + 0.1)) + img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) # add border + return img, ratio, (dw, dh) diff --git a/hordelib/nodes/facerestore_cf/facelib/detection/yolov5face/utils/extract_ckpt.py b/hordelib/nodes/facerestore_cf/facelib/detection/yolov5face/utils/extract_ckpt.py new file mode 100644 index 00000000..cddcbb07 --- /dev/null +++ b/hordelib/nodes/facerestore_cf/facelib/detection/yolov5face/utils/extract_ckpt.py @@ -0,0 +1,7 @@ +import sys + +import torch + +sys.path.insert(0, "./facelib/detection/yolov5face") +model = torch.load("facelib/detection/yolov5face/yolov5n-face.pt", map_location="cpu")["model"] +torch.save(model.state_dict(), "../../models/facedetection") diff --git a/hordelib/nodes/facerestore/facelib/detection/yolov5face/utils/general.py b/hordelib/nodes/facerestore_cf/facelib/detection/yolov5face/utils/general.py similarity index 97% rename from hordelib/nodes/facerestore/facelib/detection/yolov5face/utils/general.py rename to hordelib/nodes/facerestore_cf/facelib/detection/yolov5face/utils/general.py index 1c8e14f5..618d2f31 100644 --- a/hordelib/nodes/facerestore/facelib/detection/yolov5face/utils/general.py +++ b/hordelib/nodes/facerestore_cf/facelib/detection/yolov5face/utils/general.py @@ -1,271 +1,271 @@ -import math -import time - -import numpy as np -import torch -import torchvision - - -def check_img_size(img_size, s=32): - # Verify img_size is a multiple of stride s - new_size = make_divisible(img_size, int(s)) # ceil gs-multiple - # if new_size != img_size: - # print(f"WARNING: --img-size {img_size:g} must be multiple of max stride {s:g}, updating to {new_size:g}") - return new_size - - -def make_divisible(x, divisor): - # Returns x evenly divisible by divisor - return math.ceil(x / divisor) * divisor - - -def xyxy2xywh(x): - # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right - y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x) - y[:, 0] = (x[:, 0] + x[:, 2]) / 2 # x center - y[:, 1] = (x[:, 1] + x[:, 3]) / 2 # y center - y[:, 2] = x[:, 2] - x[:, 0] # width - y[:, 3] = x[:, 3] - x[:, 1] # height - return y - - -def xywh2xyxy(x): - # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right - y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x) - y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x - y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y - y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x - y[:, 3] = x[:, 1] + x[:, 3] / 2 # bottom right y - return y - - -def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None): - # Rescale coords (xyxy) from img1_shape to img0_shape - if ratio_pad is None: # calculate from img0_shape - gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new - pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding - else: - gain = 
ratio_pad[0][0] - pad = ratio_pad[1] - - coords[:, [0, 2]] -= pad[0] # x padding - coords[:, [1, 3]] -= pad[1] # y padding - coords[:, :4] /= gain - clip_coords(coords, img0_shape) - return coords - - -def clip_coords(boxes, img_shape): - # Clip bounding xyxy bounding boxes to image shape (height, width) - boxes[:, 0].clamp_(0, img_shape[1]) # x1 - boxes[:, 1].clamp_(0, img_shape[0]) # y1 - boxes[:, 2].clamp_(0, img_shape[1]) # x2 - boxes[:, 3].clamp_(0, img_shape[0]) # y2 - - -def box_iou(box1, box2): - # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py - """ - Return intersection-over-union (Jaccard index) of boxes. - Both sets of boxes are expected to be in (x1, y1, x2, y2) format. - Arguments: - box1 (Tensor[N, 4]) - box2 (Tensor[M, 4]) - Returns: - iou (Tensor[N, M]): the NxM matrix containing the pairwise - IoU values for every element in boxes1 and boxes2 - """ - - def box_area(box): - return (box[2] - box[0]) * (box[3] - box[1]) - - area1 = box_area(box1.T) - area2 = box_area(box2.T) - - inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) - torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2) - return inter / (area1[:, None] + area2 - inter) - - -def non_max_suppression_face(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, labels=()): - """Performs Non-Maximum Suppression (NMS) on inference results - Returns: - detections with shape: nx6 (x1, y1, x2, y2, conf, cls) - """ - - nc = prediction.shape[2] - 15 # number of classes - xc = prediction[..., 4] > conf_thres # candidates - - # Settings - # (pixels) maximum box width and height - max_wh = 4096 - time_limit = 10.0 # seconds to quit after - redundant = True # require redundant detections - multi_label = nc > 1 # multiple labels per box (adds 0.5ms/img) - merge = False # use merge-NMS - - t = time.time() - output = [torch.zeros((0, 16), device=prediction.device)] * prediction.shape[0] - for xi, x in enumerate(prediction): # image index, image inference - # Apply constraints - x = x[xc[xi]] # confidence - - # Cat apriori labels if autolabelling - if labels and len(labels[xi]): - label = labels[xi] - v = torch.zeros((len(label), nc + 15), device=x.device) - v[:, :4] = label[:, 1:5] # box - v[:, 4] = 1.0 # conf - v[range(len(label)), label[:, 0].long() + 15] = 1.0 # cls - x = torch.cat((x, v), 0) - - # If none remain process next image - if not x.shape[0]: - continue - - # Compute conf - x[:, 15:] *= x[:, 4:5] # conf = obj_conf * cls_conf - - # Box (center x, center y, width, height) to (x1, y1, x2, y2) - box = xywh2xyxy(x[:, :4]) - - # Detections matrix nx6 (xyxy, conf, landmarks, cls) - if multi_label: - i, j = (x[:, 15:] > conf_thres).nonzero(as_tuple=False).T - x = torch.cat((box[i], x[i, j + 15, None], x[:, 5:15], j[:, None].float()), 1) - else: # best class only - conf, j = x[:, 15:].max(1, keepdim=True) - x = torch.cat((box, conf, x[:, 5:15], j.float()), 1)[conf.view(-1) > conf_thres] - - # Filter by class - if classes is not None: - x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)] - - # If none remain process next image - n = x.shape[0] # number of boxes - if not n: - continue - - # Batched NMS - c = x[:, 15:16] * (0 if agnostic else max_wh) # classes - boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores - i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS - - if merge and (1 < n < 3e3): # Merge NMS (boxes merged using weighted mean) - # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4) - iou = box_iou(boxes[i], boxes) > 
iou_thres # iou matrix - weights = iou * scores[None] # box weights - x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes - if redundant: - i = i[iou.sum(1) > 1] # require redundancy - - output[xi] = x[i] - if (time.time() - t) > time_limit: - break # time limit exceeded - - return output - - -def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, labels=()): - """Performs Non-Maximum Suppression (NMS) on inference results - - Returns: - detections with shape: nx6 (x1, y1, x2, y2, conf, cls) - """ - - nc = prediction.shape[2] - 5 # number of classes - xc = prediction[..., 4] > conf_thres # candidates - - # Settings - # (pixels) maximum box width and height - max_wh = 4096 - time_limit = 10.0 # seconds to quit after - redundant = True # require redundant detections - multi_label = nc > 1 # multiple labels per box (adds 0.5ms/img) - merge = False # use merge-NMS - - t = time.time() - output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0] - for xi, x in enumerate(prediction): # image index, image inference - x = x[xc[xi]] # confidence - - # Cat apriori labels if autolabelling - if labels and len(labels[xi]): - label_id = labels[xi] - v = torch.zeros((len(label_id), nc + 5), device=x.device) - v[:, :4] = label_id[:, 1:5] # box - v[:, 4] = 1.0 # conf - v[range(len(label_id)), label_id[:, 0].long() + 5] = 1.0 # cls - x = torch.cat((x, v), 0) - - # If none remain process next image - if not x.shape[0]: - continue - - # Compute conf - x[:, 5:] *= x[:, 4:5] # conf = obj_conf * cls_conf - - # Box (center x, center y, width, height) to (x1, y1, x2, y2) - box = xywh2xyxy(x[:, :4]) - - # Detections matrix nx6 (xyxy, conf, cls) - if multi_label: - i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T - x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1) - else: # best class only - conf, j = x[:, 5:].max(1, keepdim=True) - x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres] - - # Filter by class - if classes is not None: - x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)] - - # Check shape - n = x.shape[0] # number of boxes - if not n: # no boxes - continue - - x = x[x[:, 4].argsort(descending=True)] # sort by confidence - - # Batched NMS - c = x[:, 5:6] * (0 if agnostic else max_wh) # classes - boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores - i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS - if merge and (1 < n < 3e3): # Merge NMS (boxes merged using weighted mean) - # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4) - iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix - weights = iou * scores[None] # box weights - x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes - if redundant: - i = i[iou.sum(1) > 1] # require redundancy - - output[xi] = x[i] - if (time.time() - t) > time_limit: - print(f"WARNING: NMS time limit {time_limit}s exceeded") - break # time limit exceeded - - return output - - -def scale_coords_landmarks(img1_shape, coords, img0_shape, ratio_pad=None): - # Rescale coords (xyxy) from img1_shape to img0_shape - if ratio_pad is None: # calculate from img0_shape - gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new - pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding - else: - gain = ratio_pad[0][0] - pad = ratio_pad[1] - - coords[:, [0, 2, 4, 6, 8]] -= 
pad[0] # x padding - coords[:, [1, 3, 5, 7, 9]] -= pad[1] # y padding - coords[:, :10] /= gain - coords[:, 0].clamp_(0, img0_shape[1]) # x1 - coords[:, 1].clamp_(0, img0_shape[0]) # y1 - coords[:, 2].clamp_(0, img0_shape[1]) # x2 - coords[:, 3].clamp_(0, img0_shape[0]) # y2 - coords[:, 4].clamp_(0, img0_shape[1]) # x3 - coords[:, 5].clamp_(0, img0_shape[0]) # y3 - coords[:, 6].clamp_(0, img0_shape[1]) # x4 - coords[:, 7].clamp_(0, img0_shape[0]) # y4 - coords[:, 8].clamp_(0, img0_shape[1]) # x5 - coords[:, 9].clamp_(0, img0_shape[0]) # y5 - return coords +import math +import time + +import numpy as np +import torch +import torchvision + + +def check_img_size(img_size, s=32): + # Verify img_size is a multiple of stride s + new_size = make_divisible(img_size, int(s)) # ceil gs-multiple + # if new_size != img_size: + # print(f"WARNING: --img-size {img_size:g} must be multiple of max stride {s:g}, updating to {new_size:g}") + return new_size + + +def make_divisible(x, divisor): + # Returns x evenly divisible by divisor + return math.ceil(x / divisor) * divisor + + +def xyxy2xywh(x): + # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right + y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x) + y[:, 0] = (x[:, 0] + x[:, 2]) / 2 # x center + y[:, 1] = (x[:, 1] + x[:, 3]) / 2 # y center + y[:, 2] = x[:, 2] - x[:, 0] # width + y[:, 3] = x[:, 3] - x[:, 1] # height + return y + + +def xywh2xyxy(x): + # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right + y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x) + y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x + y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y + y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x + y[:, 3] = x[:, 1] + x[:, 3] / 2 # bottom right y + return y + + +def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None): + # Rescale coords (xyxy) from img1_shape to img0_shape + if ratio_pad is None: # calculate from img0_shape + gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new + pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding + else: + gain = ratio_pad[0][0] + pad = ratio_pad[1] + + coords[:, [0, 2]] -= pad[0] # x padding + coords[:, [1, 3]] -= pad[1] # y padding + coords[:, :4] /= gain + clip_coords(coords, img0_shape) + return coords + + +def clip_coords(boxes, img_shape): + # Clip bounding xyxy bounding boxes to image shape (height, width) + boxes[:, 0].clamp_(0, img_shape[1]) # x1 + boxes[:, 1].clamp_(0, img_shape[0]) # y1 + boxes[:, 2].clamp_(0, img_shape[1]) # x2 + boxes[:, 3].clamp_(0, img_shape[0]) # y2 + + +def box_iou(box1, box2): + # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py + """ + Return intersection-over-union (Jaccard index) of boxes. + Both sets of boxes are expected to be in (x1, y1, x2, y2) format. 
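+        Example (illustrative values only):
+            >>> a = torch.tensor([[0.0, 0.0, 10.0, 10.0]])
+            >>> b = torch.tensor([[0.0, 0.0, 10.0, 10.0], [5.0, 5.0, 15.0, 15.0]])
+            >>> box_iou(a, b)  # shape (1, 2)
+            tensor([[1.0000, 0.1429]])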
+ Arguments: + box1 (Tensor[N, 4]) + box2 (Tensor[M, 4]) + Returns: + iou (Tensor[N, M]): the NxM matrix containing the pairwise + IoU values for every element in boxes1 and boxes2 + """ + + def box_area(box): + return (box[2] - box[0]) * (box[3] - box[1]) + + area1 = box_area(box1.T) + area2 = box_area(box2.T) + + inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) - torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2) + return inter / (area1[:, None] + area2 - inter) + + +def non_max_suppression_face(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, labels=()): + """Performs Non-Maximum Suppression (NMS) on inference results + Returns: + detections with shape: nx6 (x1, y1, x2, y2, conf, cls) + """ + + nc = prediction.shape[2] - 15 # number of classes + xc = prediction[..., 4] > conf_thres # candidates + + # Settings + # (pixels) maximum box width and height + max_wh = 4096 + time_limit = 10.0 # seconds to quit after + redundant = True # require redundant detections + multi_label = nc > 1 # multiple labels per box (adds 0.5ms/img) + merge = False # use merge-NMS + + t = time.time() + output = [torch.zeros((0, 16), device=prediction.device)] * prediction.shape[0] + for xi, x in enumerate(prediction): # image index, image inference + # Apply constraints + x = x[xc[xi]] # confidence + + # Cat apriori labels if autolabelling + if labels and len(labels[xi]): + label = labels[xi] + v = torch.zeros((len(label), nc + 15), device=x.device) + v[:, :4] = label[:, 1:5] # box + v[:, 4] = 1.0 # conf + v[range(len(label)), label[:, 0].long() + 15] = 1.0 # cls + x = torch.cat((x, v), 0) + + # If none remain process next image + if not x.shape[0]: + continue + + # Compute conf + x[:, 15:] *= x[:, 4:5] # conf = obj_conf * cls_conf + + # Box (center x, center y, width, height) to (x1, y1, x2, y2) + box = xywh2xyxy(x[:, :4]) + + # Detections matrix nx6 (xyxy, conf, landmarks, cls) + if multi_label: + i, j = (x[:, 15:] > conf_thres).nonzero(as_tuple=False).T + x = torch.cat((box[i], x[i, j + 15, None], x[:, 5:15], j[:, None].float()), 1) + else: # best class only + conf, j = x[:, 15:].max(1, keepdim=True) + x = torch.cat((box, conf, x[:, 5:15], j.float()), 1)[conf.view(-1) > conf_thres] + + # Filter by class + if classes is not None: + x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)] + + # If none remain process next image + n = x.shape[0] # number of boxes + if not n: + continue + + # Batched NMS + c = x[:, 15:16] * (0 if agnostic else max_wh) # classes + boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores + i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS + + if merge and (1 < n < 3e3): # Merge NMS (boxes merged using weighted mean) + # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4) + iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix + weights = iou * scores[None] # box weights + x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes + if redundant: + i = i[iou.sum(1) > 1] # require redundancy + + output[xi] = x[i] + if (time.time() - t) > time_limit: + break # time limit exceeded + + return output + + +def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, labels=()): + """Performs Non-Maximum Suppression (NMS) on inference results + + Returns: + detections with shape: nx6 (x1, y1, x2, y2, conf, cls) + """ + + nc = prediction.shape[2] - 5 # number of classes + xc = prediction[..., 4] > conf_thres # candidates + + # Settings + # (pixels) 
maximum box width and height + max_wh = 4096 + time_limit = 10.0 # seconds to quit after + redundant = True # require redundant detections + multi_label = nc > 1 # multiple labels per box (adds 0.5ms/img) + merge = False # use merge-NMS + + t = time.time() + output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0] + for xi, x in enumerate(prediction): # image index, image inference + x = x[xc[xi]] # confidence + + # Cat apriori labels if autolabelling + if labels and len(labels[xi]): + label_id = labels[xi] + v = torch.zeros((len(label_id), nc + 5), device=x.device) + v[:, :4] = label_id[:, 1:5] # box + v[:, 4] = 1.0 # conf + v[range(len(label_id)), label_id[:, 0].long() + 5] = 1.0 # cls + x = torch.cat((x, v), 0) + + # If none remain process next image + if not x.shape[0]: + continue + + # Compute conf + x[:, 5:] *= x[:, 4:5] # conf = obj_conf * cls_conf + + # Box (center x, center y, width, height) to (x1, y1, x2, y2) + box = xywh2xyxy(x[:, :4]) + + # Detections matrix nx6 (xyxy, conf, cls) + if multi_label: + i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T + x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1) + else: # best class only + conf, j = x[:, 5:].max(1, keepdim=True) + x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres] + + # Filter by class + if classes is not None: + x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)] + + # Check shape + n = x.shape[0] # number of boxes + if not n: # no boxes + continue + + x = x[x[:, 4].argsort(descending=True)] # sort by confidence + + # Batched NMS + c = x[:, 5:6] * (0 if agnostic else max_wh) # classes + boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores + i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS + if merge and (1 < n < 3e3): # Merge NMS (boxes merged using weighted mean) + # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4) + iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix + weights = iou * scores[None] # box weights + x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes + if redundant: + i = i[iou.sum(1) > 1] # require redundancy + + output[xi] = x[i] + if (time.time() - t) > time_limit: + print(f"WARNING: NMS time limit {time_limit}s exceeded") + break # time limit exceeded + + return output + + +def scale_coords_landmarks(img1_shape, coords, img0_shape, ratio_pad=None): + # Rescale coords (xyxy) from img1_shape to img0_shape + if ratio_pad is None: # calculate from img0_shape + gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new + pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding + else: + gain = ratio_pad[0][0] + pad = ratio_pad[1] + + coords[:, [0, 2, 4, 6, 8]] -= pad[0] # x padding + coords[:, [1, 3, 5, 7, 9]] -= pad[1] # y padding + coords[:, :10] /= gain + coords[:, 0].clamp_(0, img0_shape[1]) # x1 + coords[:, 1].clamp_(0, img0_shape[0]) # y1 + coords[:, 2].clamp_(0, img0_shape[1]) # x2 + coords[:, 3].clamp_(0, img0_shape[0]) # y2 + coords[:, 4].clamp_(0, img0_shape[1]) # x3 + coords[:, 5].clamp_(0, img0_shape[0]) # y3 + coords[:, 6].clamp_(0, img0_shape[1]) # x4 + coords[:, 7].clamp_(0, img0_shape[0]) # y4 + coords[:, 8].clamp_(0, img0_shape[1]) # x5 + coords[:, 9].clamp_(0, img0_shape[0]) # y5 + return coords diff --git a/hordelib/nodes/facerestore/facelib/detection/yolov5face/utils/torch_utils.py 
b/hordelib/nodes/facerestore_cf/facelib/detection/yolov5face/utils/torch_utils.py similarity index 97% rename from hordelib/nodes/facerestore/facelib/detection/yolov5face/utils/torch_utils.py rename to hordelib/nodes/facerestore_cf/facelib/detection/yolov5face/utils/torch_utils.py index af2d0658..f7029623 100644 --- a/hordelib/nodes/facerestore/facelib/detection/yolov5face/utils/torch_utils.py +++ b/hordelib/nodes/facerestore_cf/facelib/detection/yolov5face/utils/torch_utils.py @@ -1,40 +1,40 @@ -import torch -from torch import nn - - -def fuse_conv_and_bn(conv, bn): - # Fuse convolution and batchnorm layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/ - fusedconv = ( - nn.Conv2d( - conv.in_channels, - conv.out_channels, - kernel_size=conv.kernel_size, - stride=conv.stride, - padding=conv.padding, - groups=conv.groups, - bias=True, - ) - .requires_grad_(False) - .to(conv.weight.device) - ) - - # prepare filters - w_conv = conv.weight.clone().view(conv.out_channels, -1) - w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var))) - fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.size())) - - # prepare spatial bias - b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias - b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps)) - fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn) - - return fusedconv - - -def copy_attr(a, b, include=(), exclude=()): - # Copy attributes from b to a, options to only include [...] and to exclude [...] - for k, v in b.__dict__.items(): - if (include and k not in include) or k.startswith("_") or k in exclude: - continue - - setattr(a, k, v) +import torch +from torch import nn + + +def fuse_conv_and_bn(conv, bn): + # Fuse convolution and batchnorm layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/ + fusedconv = ( + nn.Conv2d( + conv.in_channels, + conv.out_channels, + kernel_size=conv.kernel_size, + stride=conv.stride, + padding=conv.padding, + groups=conv.groups, + bias=True, + ) + .requires_grad_(False) + .to(conv.weight.device) + ) + + # prepare filters + w_conv = conv.weight.clone().view(conv.out_channels, -1) + w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var))) + fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.size())) + + # prepare spatial bias + b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias + b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps)) + fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn) + + return fusedconv + + +def copy_attr(a, b, include=(), exclude=()): + # Copy attributes from b to a, options to only include [...] and to exclude [...] 
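+        # Skipped attributes: anything not in `include` when a whitelist is given,
+        # private attributes (leading underscore), and anything listed in `exclude`.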
+ for k, v in b.__dict__.items(): + if (include and k not in include) or k.startswith("_") or k in exclude: + continue + + setattr(a, k, v) diff --git a/hordelib/nodes/facerestore/facelib/parsing/__init__.py b/hordelib/nodes/facerestore_cf/facelib/parsing/__init__.py similarity index 81% rename from hordelib/nodes/facerestore/facelib/parsing/__init__.py rename to hordelib/nodes/facerestore_cf/facelib/parsing/__init__.py index b95f2d8b..f030fbfd 100644 --- a/hordelib/nodes/facerestore/facelib/parsing/__init__.py +++ b/hordelib/nodes/facerestore_cf/facelib/parsing/__init__.py @@ -1,28 +1,31 @@ -import torch - -from hordelib.nodes.facerestore.facelib.utils import load_file_from_url -from .bisenet import BiSeNet -from .parsenet import ParseNet - - -def init_parsing_model(model_name="bisenet", half=False, device="cuda"): - if model_name == "bisenet": - model = BiSeNet(num_class=19) - model_url = "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_bisenet.pth" - elif model_name == "parsenet": - model = ParseNet(in_size=512, out_size=512, parsing_ch=19) - model_url = "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_parsenet.pth" - else: - raise NotImplementedError(f"{model_name} is not implemented.") - - model_path = load_file_from_url( - url=model_url, - model_dir="../../models/facedetection", - progress=True, - file_name=None, - ) - load_net = torch.load(model_path, map_location=lambda storage, loc: storage) - model.load_state_dict(load_net, strict=True) - model.eval() - model = model.to(device) - return model +import torch + +from hordelib.nodes.facerestore_cf.facelib.utils import load_file_from_url + +from .bisenet import BiSeNet +from .parsenet import ParseNet + + +def init_parsing_model(model_name="bisenet", half=False, device="cuda"): + if model_name == "bisenet": + model = BiSeNet(num_class=19) + model_url = "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_bisenet.pth" + filename = "parsing_bisenet.pth" + elif model_name == "parsenet": + model = ParseNet(in_size=512, out_size=512, parsing_ch=19) + model_url = "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_parsenet.pth" + filename = "parsing_parsenet.pth" + else: + raise NotImplementedError(f"{model_name} is not implemented.") + + model_path = load_file_from_url( + url=model_url, + model_dir="../../models/facedetection", + progress=True, + file_name=filename, + ) + load_net = torch.load(model_path, map_location=lambda storage, loc: storage) + model.load_state_dict(load_net, strict=True) + model.eval() + model = model.to(device) + return model diff --git a/hordelib/nodes/facerestore/facelib/parsing/bisenet.py b/hordelib/nodes/facerestore_cf/facelib/parsing/bisenet.py similarity index 85% rename from hordelib/nodes/facerestore/facelib/parsing/bisenet.py rename to hordelib/nodes/facerestore_cf/facelib/parsing/bisenet.py index 3898cab7..051eec1d 100644 --- a/hordelib/nodes/facerestore/facelib/parsing/bisenet.py +++ b/hordelib/nodes/facerestore_cf/facelib/parsing/bisenet.py @@ -1,140 +1,140 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F - -from .resnet import ResNet18 - - -class ConvBNReLU(nn.Module): - - def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1): - super(ConvBNReLU, self).__init__() - self.conv = nn.Conv2d(in_chan, out_chan, kernel_size=ks, stride=stride, padding=padding, bias=False) - self.bn = nn.BatchNorm2d(out_chan) - - def forward(self, x): - x = self.conv(x) - x = F.relu(self.bn(x)) - return x - - -class 
BiSeNetOutput(nn.Module): - - def __init__(self, in_chan, mid_chan, num_class): - super(BiSeNetOutput, self).__init__() - self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1) - self.conv_out = nn.Conv2d(mid_chan, num_class, kernel_size=1, bias=False) - - def forward(self, x): - feat = self.conv(x) - out = self.conv_out(feat) - return out, feat - - -class AttentionRefinementModule(nn.Module): - - def __init__(self, in_chan, out_chan): - super(AttentionRefinementModule, self).__init__() - self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1) - self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size=1, bias=False) - self.bn_atten = nn.BatchNorm2d(out_chan) - self.sigmoid_atten = nn.Sigmoid() - - def forward(self, x): - feat = self.conv(x) - atten = F.avg_pool2d(feat, feat.size()[2:]) - atten = self.conv_atten(atten) - atten = self.bn_atten(atten) - atten = self.sigmoid_atten(atten) - out = torch.mul(feat, atten) - return out - - -class ContextPath(nn.Module): - - def __init__(self): - super(ContextPath, self).__init__() - self.resnet = ResNet18() - self.arm16 = AttentionRefinementModule(256, 128) - self.arm32 = AttentionRefinementModule(512, 128) - self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1) - self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1) - self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0) - - def forward(self, x): - feat8, feat16, feat32 = self.resnet(x) - h8, w8 = feat8.size()[2:] - h16, w16 = feat16.size()[2:] - h32, w32 = feat32.size()[2:] - - avg = F.avg_pool2d(feat32, feat32.size()[2:]) - avg = self.conv_avg(avg) - avg_up = F.interpolate(avg, (h32, w32), mode='nearest') - - feat32_arm = self.arm32(feat32) - feat32_sum = feat32_arm + avg_up - feat32_up = F.interpolate(feat32_sum, (h16, w16), mode='nearest') - feat32_up = self.conv_head32(feat32_up) - - feat16_arm = self.arm16(feat16) - feat16_sum = feat16_arm + feat32_up - feat16_up = F.interpolate(feat16_sum, (h8, w8), mode='nearest') - feat16_up = self.conv_head16(feat16_up) - - return feat8, feat16_up, feat32_up # x8, x8, x16 - - -class FeatureFusionModule(nn.Module): - - def __init__(self, in_chan, out_chan): - super(FeatureFusionModule, self).__init__() - self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0) - self.conv1 = nn.Conv2d(out_chan, out_chan // 4, kernel_size=1, stride=1, padding=0, bias=False) - self.conv2 = nn.Conv2d(out_chan // 4, out_chan, kernel_size=1, stride=1, padding=0, bias=False) - self.relu = nn.ReLU(inplace=True) - self.sigmoid = nn.Sigmoid() - - def forward(self, fsp, fcp): - fcat = torch.cat([fsp, fcp], dim=1) - feat = self.convblk(fcat) - atten = F.avg_pool2d(feat, feat.size()[2:]) - atten = self.conv1(atten) - atten = self.relu(atten) - atten = self.conv2(atten) - atten = self.sigmoid(atten) - feat_atten = torch.mul(feat, atten) - feat_out = feat_atten + feat - return feat_out - - -class BiSeNet(nn.Module): - - def __init__(self, num_class): - super(BiSeNet, self).__init__() - self.cp = ContextPath() - self.ffm = FeatureFusionModule(256, 256) - self.conv_out = BiSeNetOutput(256, 256, num_class) - self.conv_out16 = BiSeNetOutput(128, 64, num_class) - self.conv_out32 = BiSeNetOutput(128, 64, num_class) - - def forward(self, x, return_feat=False): - h, w = x.size()[2:] - feat_res8, feat_cp8, feat_cp16 = self.cp(x) # return res3b1 feature - feat_sp = feat_res8 # replace spatial path feature with res3b1 feature - feat_fuse = self.ffm(feat_sp, feat_cp8) - - out, feat = self.conv_out(feat_fuse) - out16, 
feat16 = self.conv_out16(feat_cp8) - out32, feat32 = self.conv_out32(feat_cp16) - - out = F.interpolate(out, (h, w), mode='bilinear', align_corners=True) - out16 = F.interpolate(out16, (h, w), mode='bilinear', align_corners=True) - out32 = F.interpolate(out32, (h, w), mode='bilinear', align_corners=True) - - if return_feat: - feat = F.interpolate(feat, (h, w), mode='bilinear', align_corners=True) - feat16 = F.interpolate(feat16, (h, w), mode='bilinear', align_corners=True) - feat32 = F.interpolate(feat32, (h, w), mode='bilinear', align_corners=True) - return out, out16, out32, feat, feat16, feat32 - else: - return out, out16, out32 +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .resnet import ResNet18 + + +class ConvBNReLU(nn.Module): + + def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1): + super(ConvBNReLU, self).__init__() + self.conv = nn.Conv2d(in_chan, out_chan, kernel_size=ks, stride=stride, padding=padding, bias=False) + self.bn = nn.BatchNorm2d(out_chan) + + def forward(self, x): + x = self.conv(x) + x = F.relu(self.bn(x)) + return x + + +class BiSeNetOutput(nn.Module): + + def __init__(self, in_chan, mid_chan, num_class): + super(BiSeNetOutput, self).__init__() + self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1) + self.conv_out = nn.Conv2d(mid_chan, num_class, kernel_size=1, bias=False) + + def forward(self, x): + feat = self.conv(x) + out = self.conv_out(feat) + return out, feat + + +class AttentionRefinementModule(nn.Module): + + def __init__(self, in_chan, out_chan): + super(AttentionRefinementModule, self).__init__() + self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1) + self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size=1, bias=False) + self.bn_atten = nn.BatchNorm2d(out_chan) + self.sigmoid_atten = nn.Sigmoid() + + def forward(self, x): + feat = self.conv(x) + atten = F.avg_pool2d(feat, feat.size()[2:]) + atten = self.conv_atten(atten) + atten = self.bn_atten(atten) + atten = self.sigmoid_atten(atten) + out = torch.mul(feat, atten) + return out + + +class ContextPath(nn.Module): + + def __init__(self): + super(ContextPath, self).__init__() + self.resnet = ResNet18() + self.arm16 = AttentionRefinementModule(256, 128) + self.arm32 = AttentionRefinementModule(512, 128) + self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1) + self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1) + self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0) + + def forward(self, x): + feat8, feat16, feat32 = self.resnet(x) + h8, w8 = feat8.size()[2:] + h16, w16 = feat16.size()[2:] + h32, w32 = feat32.size()[2:] + + avg = F.avg_pool2d(feat32, feat32.size()[2:]) + avg = self.conv_avg(avg) + avg_up = F.interpolate(avg, (h32, w32), mode="nearest") + + feat32_arm = self.arm32(feat32) + feat32_sum = feat32_arm + avg_up + feat32_up = F.interpolate(feat32_sum, (h16, w16), mode="nearest") + feat32_up = self.conv_head32(feat32_up) + + feat16_arm = self.arm16(feat16) + feat16_sum = feat16_arm + feat32_up + feat16_up = F.interpolate(feat16_sum, (h8, w8), mode="nearest") + feat16_up = self.conv_head16(feat16_up) + + return feat8, feat16_up, feat32_up # x8, x8, x16 + + +class FeatureFusionModule(nn.Module): + + def __init__(self, in_chan, out_chan): + super(FeatureFusionModule, self).__init__() + self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0) + self.conv1 = nn.Conv2d(out_chan, out_chan // 4, kernel_size=1, stride=1, padding=0, bias=False) + self.conv2 = 
nn.Conv2d(out_chan // 4, out_chan, kernel_size=1, stride=1, padding=0, bias=False) + self.relu = nn.ReLU(inplace=True) + self.sigmoid = nn.Sigmoid() + + def forward(self, fsp, fcp): + fcat = torch.cat([fsp, fcp], dim=1) + feat = self.convblk(fcat) + atten = F.avg_pool2d(feat, feat.size()[2:]) + atten = self.conv1(atten) + atten = self.relu(atten) + atten = self.conv2(atten) + atten = self.sigmoid(atten) + feat_atten = torch.mul(feat, atten) + feat_out = feat_atten + feat + return feat_out + + +class BiSeNet(nn.Module): + + def __init__(self, num_class): + super(BiSeNet, self).__init__() + self.cp = ContextPath() + self.ffm = FeatureFusionModule(256, 256) + self.conv_out = BiSeNetOutput(256, 256, num_class) + self.conv_out16 = BiSeNetOutput(128, 64, num_class) + self.conv_out32 = BiSeNetOutput(128, 64, num_class) + + def forward(self, x, return_feat=False): + h, w = x.size()[2:] + feat_res8, feat_cp8, feat_cp16 = self.cp(x) # return res3b1 feature + feat_sp = feat_res8 # replace spatial path feature with res3b1 feature + feat_fuse = self.ffm(feat_sp, feat_cp8) + + out, feat = self.conv_out(feat_fuse) + out16, feat16 = self.conv_out16(feat_cp8) + out32, feat32 = self.conv_out32(feat_cp16) + + out = F.interpolate(out, (h, w), mode="bilinear", align_corners=True) + out16 = F.interpolate(out16, (h, w), mode="bilinear", align_corners=True) + out32 = F.interpolate(out32, (h, w), mode="bilinear", align_corners=True) + + if return_feat: + feat = F.interpolate(feat, (h, w), mode="bilinear", align_corners=True) + feat16 = F.interpolate(feat16, (h, w), mode="bilinear", align_corners=True) + feat32 = F.interpolate(feat32, (h, w), mode="bilinear", align_corners=True) + return out, out16, out32, feat, feat16, feat32 + else: + return out, out16, out32 diff --git a/hordelib/nodes/facerestore/facelib/parsing/parsenet.py b/hordelib/nodes/facerestore_cf/facelib/parsing/parsenet.py similarity index 70% rename from hordelib/nodes/facerestore/facelib/parsing/parsenet.py rename to hordelib/nodes/facerestore_cf/facelib/parsing/parsenet.py index e178ebe4..b2ed997a 100644 --- a/hordelib/nodes/facerestore/facelib/parsing/parsenet.py +++ b/hordelib/nodes/facerestore_cf/facelib/parsing/parsenet.py @@ -1,194 +1,199 @@ -"""Modified from https://github.com/chaofengc/PSFRGAN -""" -import numpy as np -import torch.nn as nn -from torch.nn import functional as F - - -class NormLayer(nn.Module): - """Normalization Layers. - - Args: - channels: input channels, for batch norm and instance norm. - input_size: input shape without batch size, for layer norm. - """ - - def __init__(self, channels, normalize_shape=None, norm_type='bn'): - super(NormLayer, self).__init__() - norm_type = norm_type.lower() - self.norm_type = norm_type - if norm_type == 'bn': - self.norm = nn.BatchNorm2d(channels, affine=True) - elif norm_type == 'in': - self.norm = nn.InstanceNorm2d(channels, affine=False) - elif norm_type == 'gn': - self.norm = nn.GroupNorm(32, channels, affine=True) - elif norm_type == 'pixel': - self.norm = lambda x: F.normalize(x, p=2, dim=1) - elif norm_type == 'layer': - self.norm = nn.LayerNorm(normalize_shape) - elif norm_type == 'none': - self.norm = lambda x: x * 1.0 - else: - assert 1 == 0, f'Norm type {norm_type} not support.' - - def forward(self, x, ref=None): - if self.norm_type == 'spade': - return self.norm(x, ref) - else: - return self.norm(x) - - -class ReluLayer(nn.Module): - """Relu Layer. 
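A minimal sketch of exercising the BiSeNet face-parsing network defined above with random weights; it assumes the module stays importable as hordelib.nodes.facerestore_cf.facelib.parsing.bisenet after the rename, and that 19 classes matches the 19-entry MASK_COLORMAP used by the restoration helper later in this patch:

import torch

from hordelib.nodes.facerestore_cf.facelib.parsing.bisenet import BiSeNet

net = BiSeNet(num_class=19).eval()
dummy = torch.randn(1, 3, 512, 512)  # (N, C, H, W); the helper feeds 512x512 face crops
with torch.no_grad():
    out, out16, out32 = net(dummy)   # three heads, each upsampled back to the input resolution
print(out.shape)                     # torch.Size([1, 19, 512, 512])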
- - Args: - relu type: type of relu layer, candidates are - - ReLU - - LeakyReLU: default relu slope 0.2 - - PRelu - - SELU - - none: direct pass - """ - - def __init__(self, channels, relu_type='relu'): - super(ReluLayer, self).__init__() - relu_type = relu_type.lower() - if relu_type == 'relu': - self.func = nn.ReLU(True) - elif relu_type == 'leakyrelu': - self.func = nn.LeakyReLU(0.2, inplace=True) - elif relu_type == 'prelu': - self.func = nn.PReLU(channels) - elif relu_type == 'selu': - self.func = nn.SELU(True) - elif relu_type == 'none': - self.func = lambda x: x * 1.0 - else: - assert 1 == 0, f'Relu type {relu_type} not support.' - - def forward(self, x): - return self.func(x) - - -class ConvLayer(nn.Module): - - def __init__(self, - in_channels, - out_channels, - kernel_size=3, - scale='none', - norm_type='none', - relu_type='none', - use_pad=True, - bias=True): - super(ConvLayer, self).__init__() - self.use_pad = use_pad - self.norm_type = norm_type - if norm_type in ['bn']: - bias = False - - stride = 2 if scale == 'down' else 1 - - self.scale_func = lambda x: x - if scale == 'up': - self.scale_func = lambda x: nn.functional.interpolate(x, scale_factor=2, mode='nearest') - - self.reflection_pad = nn.ReflectionPad2d(int(np.ceil((kernel_size - 1.) / 2))) - self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, bias=bias) - - self.relu = ReluLayer(out_channels, relu_type) - self.norm = NormLayer(out_channels, norm_type=norm_type) - - def forward(self, x): - out = self.scale_func(x) - if self.use_pad: - out = self.reflection_pad(out) - out = self.conv2d(out) - out = self.norm(out) - out = self.relu(out) - return out - - -class ResidualBlock(nn.Module): - """ - Residual block recommended in: http://torch.ch/blog/2016/02/04/resnets.html - """ - - def __init__(self, c_in, c_out, relu_type='prelu', norm_type='bn', scale='none'): - super(ResidualBlock, self).__init__() - - if scale == 'none' and c_in == c_out: - self.shortcut_func = lambda x: x - else: - self.shortcut_func = ConvLayer(c_in, c_out, 3, scale) - - scale_config_dict = {'down': ['none', 'down'], 'up': ['up', 'none'], 'none': ['none', 'none']} - scale_conf = scale_config_dict[scale] - - self.conv1 = ConvLayer(c_in, c_out, 3, scale_conf[0], norm_type=norm_type, relu_type=relu_type) - self.conv2 = ConvLayer(c_out, c_out, 3, scale_conf[1], norm_type=norm_type, relu_type='none') - - def forward(self, x): - identity = self.shortcut_func(x) - - res = self.conv1(x) - res = self.conv2(res) - return identity + res - - -class ParseNet(nn.Module): - - def __init__(self, - in_size=128, - out_size=128, - min_feat_size=32, - base_ch=64, - parsing_ch=19, - res_depth=10, - relu_type='LeakyReLU', - norm_type='bn', - ch_range=[32, 256]): - super().__init__() - self.res_depth = res_depth - act_args = {'norm_type': norm_type, 'relu_type': relu_type} - min_ch, max_ch = ch_range - - ch_clip = lambda x: max(min_ch, min(x, max_ch)) # noqa: E731 - min_feat_size = min(in_size, min_feat_size) - - down_steps = int(np.log2(in_size // min_feat_size)) - up_steps = int(np.log2(out_size // min_feat_size)) - - # =============== define encoder-body-decoder ==================== - self.encoder = [] - self.encoder.append(ConvLayer(3, base_ch, 3, 1)) - head_ch = base_ch - for i in range(down_steps): - cin, cout = ch_clip(head_ch), ch_clip(head_ch * 2) - self.encoder.append(ResidualBlock(cin, cout, scale='down', **act_args)) - head_ch = head_ch * 2 - - self.body = [] - for i in range(res_depth): - self.body.append(ResidualBlock(ch_clip(head_ch), 
ch_clip(head_ch), **act_args)) - - self.decoder = [] - for i in range(up_steps): - cin, cout = ch_clip(head_ch), ch_clip(head_ch // 2) - self.decoder.append(ResidualBlock(cin, cout, scale='up', **act_args)) - head_ch = head_ch // 2 - - self.encoder = nn.Sequential(*self.encoder) - self.body = nn.Sequential(*self.body) - self.decoder = nn.Sequential(*self.decoder) - self.out_img_conv = ConvLayer(ch_clip(head_ch), 3) - self.out_mask_conv = ConvLayer(ch_clip(head_ch), parsing_ch) - - def forward(self, x): - feat = self.encoder(x) - x = feat + self.body(feat) - x = self.decoder(x) - out_img = self.out_img_conv(x) - out_mask = self.out_mask_conv(x) - return out_mask, out_img +"""Modified from https://github.com/chaofengc/PSFRGAN +""" + +import numpy as np +import torch.nn as nn +from torch.nn import functional as F + + +class NormLayer(nn.Module): + """Normalization Layers. + + Args: + channels: input channels, for batch norm and instance norm. + input_size: input shape without batch size, for layer norm. + """ + + def __init__(self, channels, normalize_shape=None, norm_type="bn"): + super(NormLayer, self).__init__() + norm_type = norm_type.lower() + self.norm_type = norm_type + if norm_type == "bn": + self.norm = nn.BatchNorm2d(channels, affine=True) + elif norm_type == "in": + self.norm = nn.InstanceNorm2d(channels, affine=False) + elif norm_type == "gn": + self.norm = nn.GroupNorm(32, channels, affine=True) + elif norm_type == "pixel": + self.norm = lambda x: F.normalize(x, p=2, dim=1) + elif norm_type == "layer": + self.norm = nn.LayerNorm(normalize_shape) + elif norm_type == "none": + self.norm = lambda x: x * 1.0 + else: + assert 1 == 0, f"Norm type {norm_type} not support." + + def forward(self, x, ref=None): + if self.norm_type == "spade": + return self.norm(x, ref) + else: + return self.norm(x) + + +class ReluLayer(nn.Module): + """Relu Layer. + + Args: + relu type: type of relu layer, candidates are + - ReLU + - LeakyReLU: default relu slope 0.2 + - PRelu + - SELU + - none: direct pass + """ + + def __init__(self, channels, relu_type="relu"): + super(ReluLayer, self).__init__() + relu_type = relu_type.lower() + if relu_type == "relu": + self.func = nn.ReLU(True) + elif relu_type == "leakyrelu": + self.func = nn.LeakyReLU(0.2, inplace=True) + elif relu_type == "prelu": + self.func = nn.PReLU(channels) + elif relu_type == "selu": + self.func = nn.SELU(True) + elif relu_type == "none": + self.func = lambda x: x * 1.0 + else: + assert 1 == 0, f"Relu type {relu_type} not support." 
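ParseNet in this module (shown in full in the removed block above and reproduced below with only formatting changes) maps a face crop to parsing logits plus a coarse reconstruction. A minimal random-weight sketch, assuming the renamed facerestore_cf import path and the 512-pixel face size used elsewhere in this patch:

import torch

from hordelib.nodes.facerestore_cf.facelib.parsing.parsenet import ParseNet

net = ParseNet(in_size=512, out_size=512, parsing_ch=19).eval()
face = torch.randn(1, 3, 512, 512)
with torch.no_grad():
    out_mask, out_img = net(face)
# out_mask: (1, 19, 512, 512) parsing logits; out_img: (1, 3, 512, 512) reconstructed face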
+ + def forward(self, x): + return self.func(x) + + +class ConvLayer(nn.Module): + + def __init__( + self, + in_channels, + out_channels, + kernel_size=3, + scale="none", + norm_type="none", + relu_type="none", + use_pad=True, + bias=True, + ): + super(ConvLayer, self).__init__() + self.use_pad = use_pad + self.norm_type = norm_type + if norm_type in ["bn"]: + bias = False + + stride = 2 if scale == "down" else 1 + + self.scale_func = lambda x: x + if scale == "up": + self.scale_func = lambda x: nn.functional.interpolate(x, scale_factor=2, mode="nearest") + + self.reflection_pad = nn.ReflectionPad2d(int(np.ceil((kernel_size - 1.0) / 2))) + self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, bias=bias) + + self.relu = ReluLayer(out_channels, relu_type) + self.norm = NormLayer(out_channels, norm_type=norm_type) + + def forward(self, x): + out = self.scale_func(x) + if self.use_pad: + out = self.reflection_pad(out) + out = self.conv2d(out) + out = self.norm(out) + out = self.relu(out) + return out + + +class ResidualBlock(nn.Module): + """ + Residual block recommended in: http://torch.ch/blog/2016/02/04/resnets.html + """ + + def __init__(self, c_in, c_out, relu_type="prelu", norm_type="bn", scale="none"): + super(ResidualBlock, self).__init__() + + if scale == "none" and c_in == c_out: + self.shortcut_func = lambda x: x + else: + self.shortcut_func = ConvLayer(c_in, c_out, 3, scale) + + scale_config_dict = {"down": ["none", "down"], "up": ["up", "none"], "none": ["none", "none"]} + scale_conf = scale_config_dict[scale] + + self.conv1 = ConvLayer(c_in, c_out, 3, scale_conf[0], norm_type=norm_type, relu_type=relu_type) + self.conv2 = ConvLayer(c_out, c_out, 3, scale_conf[1], norm_type=norm_type, relu_type="none") + + def forward(self, x): + identity = self.shortcut_func(x) + + res = self.conv1(x) + res = self.conv2(res) + return identity + res + + +class ParseNet(nn.Module): + + def __init__( + self, + in_size=128, + out_size=128, + min_feat_size=32, + base_ch=64, + parsing_ch=19, + res_depth=10, + relu_type="LeakyReLU", + norm_type="bn", + ch_range=[32, 256], + ): + super().__init__() + self.res_depth = res_depth + act_args = {"norm_type": norm_type, "relu_type": relu_type} + min_ch, max_ch = ch_range + + ch_clip = lambda x: max(min_ch, min(x, max_ch)) # noqa: E731 + min_feat_size = min(in_size, min_feat_size) + + down_steps = int(np.log2(in_size // min_feat_size)) + up_steps = int(np.log2(out_size // min_feat_size)) + + # =============== define encoder-body-decoder ==================== + self.encoder = [] + self.encoder.append(ConvLayer(3, base_ch, 3, 1)) + head_ch = base_ch + for i in range(down_steps): + cin, cout = ch_clip(head_ch), ch_clip(head_ch * 2) + self.encoder.append(ResidualBlock(cin, cout, scale="down", **act_args)) + head_ch = head_ch * 2 + + self.body = [] + for i in range(res_depth): + self.body.append(ResidualBlock(ch_clip(head_ch), ch_clip(head_ch), **act_args)) + + self.decoder = [] + for i in range(up_steps): + cin, cout = ch_clip(head_ch), ch_clip(head_ch // 2) + self.decoder.append(ResidualBlock(cin, cout, scale="up", **act_args)) + head_ch = head_ch // 2 + + self.encoder = nn.Sequential(*self.encoder) + self.body = nn.Sequential(*self.body) + self.decoder = nn.Sequential(*self.decoder) + self.out_img_conv = ConvLayer(ch_clip(head_ch), 3) + self.out_mask_conv = ConvLayer(ch_clip(head_ch), parsing_ch) + + def forward(self, x): + feat = self.encoder(x) + x = feat + self.body(feat) + x = self.decoder(x) + out_img = self.out_img_conv(x) + out_mask = 
self.out_mask_conv(x) + return out_mask, out_img diff --git a/hordelib/nodes/facerestore/facelib/parsing/resnet.py b/hordelib/nodes/facerestore_cf/facelib/parsing/resnet.py similarity index 97% rename from hordelib/nodes/facerestore/facelib/parsing/resnet.py rename to hordelib/nodes/facerestore_cf/facelib/parsing/resnet.py index fec8e82c..e7cc283d 100644 --- a/hordelib/nodes/facerestore/facelib/parsing/resnet.py +++ b/hordelib/nodes/facerestore_cf/facelib/parsing/resnet.py @@ -1,69 +1,69 @@ -import torch.nn as nn -import torch.nn.functional as F - - -def conv3x3(in_planes, out_planes, stride=1): - """3x3 convolution with padding""" - return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) - - -class BasicBlock(nn.Module): - - def __init__(self, in_chan, out_chan, stride=1): - super(BasicBlock, self).__init__() - self.conv1 = conv3x3(in_chan, out_chan, stride) - self.bn1 = nn.BatchNorm2d(out_chan) - self.conv2 = conv3x3(out_chan, out_chan) - self.bn2 = nn.BatchNorm2d(out_chan) - self.relu = nn.ReLU(inplace=True) - self.downsample = None - if in_chan != out_chan or stride != 1: - self.downsample = nn.Sequential( - nn.Conv2d(in_chan, out_chan, kernel_size=1, stride=stride, bias=False), - nn.BatchNorm2d(out_chan), - ) - - def forward(self, x): - residual = self.conv1(x) - residual = F.relu(self.bn1(residual)) - residual = self.conv2(residual) - residual = self.bn2(residual) - - shortcut = x - if self.downsample is not None: - shortcut = self.downsample(x) - - out = shortcut + residual - out = self.relu(out) - return out - - -def create_layer_basic(in_chan, out_chan, bnum, stride=1): - layers = [BasicBlock(in_chan, out_chan, stride=stride)] - for i in range(bnum - 1): - layers.append(BasicBlock(out_chan, out_chan, stride=1)) - return nn.Sequential(*layers) - - -class ResNet18(nn.Module): - - def __init__(self): - super(ResNet18, self).__init__() - self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) - self.bn1 = nn.BatchNorm2d(64) - self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) - self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1) - self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2) - self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2) - self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2) - - def forward(self, x): - x = self.conv1(x) - x = F.relu(self.bn1(x)) - x = self.maxpool(x) - - x = self.layer1(x) - feat8 = self.layer2(x) # 1/8 - feat16 = self.layer3(feat8) # 1/16 - feat32 = self.layer4(feat16) # 1/32 - return feat8, feat16, feat32 +import torch.nn as nn +import torch.nn.functional as F + + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) + + +class BasicBlock(nn.Module): + + def __init__(self, in_chan, out_chan, stride=1): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(in_chan, out_chan, stride) + self.bn1 = nn.BatchNorm2d(out_chan) + self.conv2 = conv3x3(out_chan, out_chan) + self.bn2 = nn.BatchNorm2d(out_chan) + self.relu = nn.ReLU(inplace=True) + self.downsample = None + if in_chan != out_chan or stride != 1: + self.downsample = nn.Sequential( + nn.Conv2d(in_chan, out_chan, kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(out_chan), + ) + + def forward(self, x): + residual = self.conv1(x) + residual = F.relu(self.bn1(residual)) + residual = self.conv2(residual) + residual = self.bn2(residual) + + shortcut = x + if 
self.downsample is not None: + shortcut = self.downsample(x) + + out = shortcut + residual + out = self.relu(out) + return out + + +def create_layer_basic(in_chan, out_chan, bnum, stride=1): + layers = [BasicBlock(in_chan, out_chan, stride=stride)] + for i in range(bnum - 1): + layers.append(BasicBlock(out_chan, out_chan, stride=1)) + return nn.Sequential(*layers) + + +class ResNet18(nn.Module): + + def __init__(self): + super(ResNet18, self).__init__() + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1) + self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2) + self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2) + self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2) + + def forward(self, x): + x = self.conv1(x) + x = F.relu(self.bn1(x)) + x = self.maxpool(x) + + x = self.layer1(x) + feat8 = self.layer2(x) # 1/8 + feat16 = self.layer3(feat8) # 1/16 + feat32 = self.layer4(feat16) # 1/32 + return feat8, feat16, feat32 diff --git a/hordelib/nodes/facerestore_cf/facelib/utils/__init__.py b/hordelib/nodes/facerestore_cf/facelib/utils/__init__.py new file mode 100644 index 00000000..47ef2010 --- /dev/null +++ b/hordelib/nodes/facerestore_cf/facelib/utils/__init__.py @@ -0,0 +1,13 @@ +from .face_utils import align_crop_face_landmarks, compute_increased_bbox, get_valid_bboxes, paste_face_back +from .misc import download_pretrained_models, img2tensor, load_file_from_url, scandir + +__all__ = [ + "align_crop_face_landmarks", + "compute_increased_bbox", + "get_valid_bboxes", + "load_file_from_url", + "download_pretrained_models", + "paste_face_back", + "img2tensor", + "scandir", +] diff --git a/hordelib/nodes/facerestore/facelib/utils/face_restoration_helper.py b/hordelib/nodes/facerestore_cf/facelib/utils/face_restoration_helper.py similarity index 78% rename from hordelib/nodes/facerestore/facelib/utils/face_restoration_helper.py rename to hordelib/nodes/facerestore_cf/facelib/utils/face_restoration_helper.py index 7917344a..ab6ce3c0 100644 --- a/hordelib/nodes/facerestore/facelib/utils/face_restoration_helper.py +++ b/hordelib/nodes/facerestore_cf/facelib/utils/face_restoration_helper.py @@ -1,561 +1,476 @@ -import cv2 -import numpy as np -import os -import torch -from torchvision.transforms.functional import normalize - -from hordelib.nodes.facerestore.facelib.detection import init_detection_model -from hordelib.nodes.facerestore.facelib.parsing import init_parsing_model -from hordelib.nodes.facerestore.facelib.utils.misc import img2tensor, imwrite - - -def get_largest_face(det_faces, h, w): - def get_location(val, length): - if val < 0: - return 0 - elif val > length: - return length - else: - return val - - face_areas = [] - for det_face in det_faces: - left = get_location(det_face[0], w) - right = get_location(det_face[2], w) - top = get_location(det_face[1], h) - bottom = get_location(det_face[3], h) - face_area = (right - left) * (bottom - top) - face_areas.append(face_area) - largest_idx = face_areas.index(max(face_areas)) - return det_faces[largest_idx], largest_idx - - -def get_center_face(det_faces, h=0, w=0, center=None): - if center is not None: - center = np.array(center) - else: - center = np.array([w / 2, h / 2]) - center_dist = [] - for det_face in det_faces: - face_center = np.array( - [(det_face[0] + det_face[2]) / 2, (det_face[1] + det_face[3]) / 2] - ) - dist = 
np.linalg.norm(face_center - center) - center_dist.append(dist) - center_idx = center_dist.index(min(center_dist)) - return det_faces[center_idx], center_idx - - -class FaceRestoreHelper(object): - """Helper for the face restoration pipeline (base class).""" - - def __init__( - self, - upscale_factor, - face_size=512, - crop_ratio=(1, 1), - det_model="retinaface_resnet50", - save_ext="png", - template_3points=False, - pad_blur=False, - use_parse=False, - device=None, - ): - self.template_3points = template_3points # improve robustness - self.upscale_factor = upscale_factor - # the cropped face ratio based on the square face - self.crop_ratio = crop_ratio # (h, w) - assert ( - self.crop_ratio[0] >= 1 and self.crop_ratio[1] >= 1 - ), "crop ration only supports >=1" - self.face_size = ( - int(face_size * self.crop_ratio[1]), - int(face_size * self.crop_ratio[0]), - ) - - if self.template_3points: - self.face_template = np.array([[192, 240], [319, 240], [257, 371]]) - else: - # standard 5 landmarks for FFHQ faces with 512 x 512 - # facexlib - self.face_template = np.array( - [ - [192.98138, 239.94708], - [318.90277, 240.1936], - [256.63416, 314.01935], - [201.26117, 371.41043], - [313.08905, 371.15118], - ] - ) - - # dlib: left_eye: 36:41 right_eye: 42:47 nose: 30,32,33,34 left mouth corner: 48 right mouth corner: 54 - # self.face_template = np.array([[193.65928, 242.98541], [318.32558, 243.06108], [255.67984, 328.82894], - # [198.22603, 372.82502], [313.91018, 372.75659]]) - - self.face_template = self.face_template * (face_size / 512.0) - if self.crop_ratio[0] > 1: - self.face_template[:, 1] += face_size * (self.crop_ratio[0] - 1) / 2 - if self.crop_ratio[1] > 1: - self.face_template[:, 0] += face_size * (self.crop_ratio[1] - 1) / 2 - self.save_ext = save_ext - self.pad_blur = pad_blur - if self.pad_blur is True: - self.template_3points = False - - self.all_landmarks_5 = [] - self.det_faces = [] - self.affine_matrices = [] - self.inverse_affine_matrices = [] - self.cropped_faces = [] - self.restored_faces = [] - self.pad_input_imgs = [] - - if device is None: - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - else: - self.device = device - - # init face detection model - self.face_det = init_detection_model(det_model, half=False, device=self.device) - - # init face parsing model - self.use_parse = use_parse - self.face_parse = init_parsing_model(model_name="parsenet", device=self.device) - - def set_upscale_factor(self, upscale_factor): - self.upscale_factor = upscale_factor - - def read_image(self, img): - """img can be image path or cv2 loaded image.""" - # self.input_img is Numpy array, (h, w, c), BGR, uint8, [0, 255] - if isinstance(img, str): - img = cv2.imread(img) - - if np.max(img) > 256: # 16-bit image - img = img / 65535 * 255 - if len(img.shape) == 2: # gray image - img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) - elif img.shape[2] == 4: # BGRA image with alpha channel - img = img[:, :, 0:3] - - self.input_img = img - - if min(self.input_img.shape[:2]) < 512: - f = 512.0 / min(self.input_img.shape[:2]) - self.input_img = cv2.resize( - self.input_img, (0, 0), fx=f, fy=f, interpolation=cv2.INTER_LINEAR - ) - - def get_face_landmarks_5( - self, - only_keep_largest=False, - only_center_face=False, - resize=None, - blur_ratio=0.01, - eye_dist_threshold=None, - ): - if resize is None: - scale = 1 - input_img = self.input_img - else: - h, w = self.input_img.shape[0:2] - scale = resize / min(h, w) - scale = max(1, scale) # always scale up - h, w = int(h * scale), 
int(w * scale) - interp = cv2.INTER_AREA if scale < 1 else cv2.INTER_LINEAR - input_img = cv2.resize(self.input_img, (w, h), interpolation=interp) - - with torch.no_grad(): - bboxes = self.face_det.detect_faces(input_img) - - if bboxes is None or bboxes.shape[0] == 0: - return 0 - else: - bboxes = bboxes / scale - - for bbox in bboxes: - # remove faces with too small eye distance: side faces or too small faces - eye_dist = np.linalg.norm([bbox[6] - bbox[8], bbox[7] - bbox[9]]) - if eye_dist_threshold is not None and (eye_dist < eye_dist_threshold): - continue - - if self.template_3points: - landmark = np.array([[bbox[i], bbox[i + 1]] for i in range(5, 11, 2)]) - else: - landmark = np.array([[bbox[i], bbox[i + 1]] for i in range(5, 15, 2)]) - self.all_landmarks_5.append(landmark) - self.det_faces.append(bbox[0:5]) - - if len(self.det_faces) == 0: - return 0 - if only_keep_largest: - h, w, _ = self.input_img.shape - self.det_faces, largest_idx = get_largest_face(self.det_faces, h, w) - self.all_landmarks_5 = [self.all_landmarks_5[largest_idx]] - elif only_center_face: - h, w, _ = self.input_img.shape - self.det_faces, center_idx = get_center_face(self.det_faces, h, w) - self.all_landmarks_5 = [self.all_landmarks_5[center_idx]] - - # pad blurry images - if self.pad_blur: - self.pad_input_imgs = [] - for landmarks in self.all_landmarks_5: - # get landmarks - eye_left = landmarks[0, :] - eye_right = landmarks[1, :] - eye_avg = (eye_left + eye_right) * 0.5 - mouth_avg = (landmarks[3, :] + landmarks[4, :]) * 0.5 - eye_to_eye = eye_right - eye_left - eye_to_mouth = mouth_avg - eye_avg - - # Get the oriented crop rectangle - # x: half width of the oriented crop rectangle - x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1] - # - np.flipud(eye_to_mouth) * [-1, 1]: rotate 90 clockwise - # norm with the hypotenuse: get the direction - x /= np.hypot(*x) # get the hypotenuse of a right triangle - rect_scale = 1.5 - x *= max( - np.hypot(*eye_to_eye) * 2.0 * rect_scale, - np.hypot(*eye_to_mouth) * 1.8 * rect_scale, - ) - # y: half height of the oriented crop rectangle - y = np.flipud(x) * [-1, 1] - - # c: center - c = eye_avg + eye_to_mouth * 0.1 - # quad: (left_top, left_bottom, right_bottom, right_top) - quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y]) - # qsize: side length of the square - qsize = np.hypot(*x) * 2 - border = max(int(np.rint(qsize * 0.1)), 3) - - # get pad - # pad: (width_left, height_top, width_right, height_bottom) - pad = ( - int(np.floor(min(quad[:, 0]))), - int(np.floor(min(quad[:, 1]))), - int(np.ceil(max(quad[:, 0]))), - int(np.ceil(max(quad[:, 1]))), - ) - pad = [ - max(-pad[0] + border, 1), - max(-pad[1] + border, 1), - max(pad[2] - self.input_img.shape[0] + border, 1), - max(pad[3] - self.input_img.shape[1] + border, 1), - ] - - if max(pad) > 1: - # pad image - pad_img = np.pad( - self.input_img, - ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), - "reflect", - ) - # modify landmark coords - landmarks[:, 0] += pad[0] - landmarks[:, 1] += pad[1] - # blur pad images - h, w, _ = pad_img.shape - y, x, _ = np.ogrid[:h, :w, :1] - mask = np.maximum( - 1.0 - - np.minimum( - np.float32(x) / pad[0], np.float32(w - 1 - x) / pad[2] - ), - 1.0 - - np.minimum( - np.float32(y) / pad[1], np.float32(h - 1 - y) / pad[3] - ), - ) - blur = int(qsize * blur_ratio) - if blur % 2 == 0: - blur += 1 - blur_img = cv2.boxFilter(pad_img, 0, ksize=(blur, blur)) - # blur_img = cv2.GaussianBlur(pad_img, (blur, blur), 0) - - pad_img = pad_img.astype("float32") - pad_img += (blur_img - pad_img) * 
np.clip( - mask * 3.0 + 1.0, 0.0, 1.0 - ) - pad_img += (np.median(pad_img, axis=(0, 1)) - pad_img) * np.clip( - mask, 0.0, 1.0 - ) - pad_img = np.clip(pad_img, 0, 255) # float32, [0, 255] - self.pad_input_imgs.append(pad_img) - else: - self.pad_input_imgs.append(np.copy(self.input_img)) - - return len(self.all_landmarks_5) - - def align_warp_face(self, save_cropped_path=None, border_mode="constant"): - """Align and warp faces with face template.""" - if self.pad_blur: - assert len(self.pad_input_imgs) == len( - self.all_landmarks_5 - ), f"Mismatched samples: {len(self.pad_input_imgs)} and {len(self.all_landmarks_5)}" - for idx, landmark in enumerate(self.all_landmarks_5): - # use 5 landmarks to get affine matrix - # use cv2.LMEDS method for the equivalence to skimage transform - # ref: https://blog.csdn.net/yichxi/article/details/115827338 - affine_matrix = cv2.estimateAffinePartial2D( - landmark, self.face_template, method=cv2.LMEDS - )[0] - self.affine_matrices.append(affine_matrix) - # warp and crop faces - if border_mode == "constant": - border_mode = cv2.BORDER_CONSTANT - elif border_mode == "reflect101": - border_mode = cv2.BORDER_REFLECT101 - elif border_mode == "reflect": - border_mode = cv2.BORDER_REFLECT - if self.pad_blur: - input_img = self.pad_input_imgs[idx] - else: - input_img = self.input_img - cropped_face = cv2.warpAffine( - input_img, - affine_matrix, - self.face_size, - borderMode=border_mode, - borderValue=(135, 133, 132), - ) # gray - self.cropped_faces.append(cropped_face) - # save the cropped face - if save_cropped_path is not None: - path = os.path.splitext(save_cropped_path)[0] - save_path = f"{path}_{idx:02d}.{self.save_ext}" - imwrite(cropped_face, save_path) - - def get_inverse_affine(self, save_inverse_affine_path=None): - """Get inverse affine matrix.""" - for idx, affine_matrix in enumerate(self.affine_matrices): - inverse_affine = cv2.invertAffineTransform(affine_matrix) - inverse_affine *= self.upscale_factor - self.inverse_affine_matrices.append(inverse_affine) - # save inverse affine matrices - if save_inverse_affine_path is not None: - path, _ = os.path.splitext(save_inverse_affine_path) - save_path = f"{path}_{idx:02d}.pth" - torch.save(inverse_affine, save_path) - - def add_restored_face(self, face): - self.restored_faces.append(face) - - def paste_faces_to_input_image( - self, save_path=None, upsample_img=None, draw_box=False, face_upsampler=None - ): - h, w, _ = self.input_img.shape - h_up, w_up = int(h * self.upscale_factor), int(w * self.upscale_factor) - - if upsample_img is None: - # simply resize the background - # upsample_img = cv2.resize(self.input_img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4) - upsample_img = cv2.resize( - self.input_img, (w_up, h_up), interpolation=cv2.INTER_LINEAR - ) - else: - upsample_img = cv2.resize( - upsample_img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4 - ) - - assert len(self.restored_faces) == len( - self.inverse_affine_matrices - ), "length of restored_faces and affine_matrices are different." 
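get_largest_face and get_center_face, defined at the top of this file in both the old and new versions, select a single detection out of several. A minimal sketch with hypothetical boxes, using the renamed facerestore_cf import path:

from hordelib.nodes.facerestore_cf.facelib.utils.face_restoration_helper import (
    get_center_face,
    get_largest_face,
)

# two hypothetical detections as [x1, y1, x2, y2, score] in a 640x480 frame
det_faces = [[10, 10, 110, 140, 0.9], [200, 120, 420, 400, 0.8]]
largest_face, largest_idx = get_largest_face(det_faces, h=480, w=640)  # the 220x280 box wins on area
center_face, center_idx = get_center_face(det_faces, h=480, w=640)     # the box closest to (320, 240) wins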
- - inv_mask_borders = [] - for restored_face, inverse_affine in zip( - self.restored_faces, self.inverse_affine_matrices - ): - if face_upsampler is not None: - restored_face = face_upsampler.enhance( - restored_face, outscale=self.upscale_factor - )[0] - inverse_affine /= self.upscale_factor - inverse_affine[:, 2] *= self.upscale_factor - face_size = ( - self.face_size[0] * self.upscale_factor, - self.face_size[1] * self.upscale_factor, - ) - else: - # Add an offset to inverse affine matrix, for more precise back alignment - if self.upscale_factor > 1: - extra_offset = 0.5 * self.upscale_factor - else: - extra_offset = 0 - inverse_affine[:, 2] += extra_offset - face_size = self.face_size - inv_restored = cv2.warpAffine(restored_face, inverse_affine, (w_up, h_up)) - - # if draw_box or not self.use_parse: # use square parse maps - # mask = np.ones(face_size, dtype=np.float32) - # inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up)) - # # remove the black borders - # inv_mask_erosion = cv2.erode( - # inv_mask, np.ones((int(2 * self.upscale_factor), int(2 * self.upscale_factor)), np.uint8)) - # pasted_face = inv_mask_erosion[:, :, None] * inv_restored - # total_face_area = np.sum(inv_mask_erosion) # // 3 - # # add border - # if draw_box: - # h, w = face_size - # mask_border = np.ones((h, w, 3), dtype=np.float32) - # border = int(1400/np.sqrt(total_face_area)) - # mask_border[border:h-border, border:w-border,:] = 0 - # inv_mask_border = cv2.warpAffine(mask_border, inverse_affine, (w_up, h_up)) - # inv_mask_borders.append(inv_mask_border) - # if not self.use_parse: - # # compute the fusion edge based on the area of face - # w_edge = int(total_face_area**0.5) // 20 - # erosion_radius = w_edge * 2 - # inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8)) - # blur_size = w_edge * 2 - # inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0) - # if len(upsample_img.shape) == 2: # upsample_img is gray image - # upsample_img = upsample_img[:, :, None] - # inv_soft_mask = inv_soft_mask[:, :, None] - - # always use square mask - mask = np.ones(face_size, dtype=np.float32) - inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up)) - # remove the black borders - inv_mask_erosion = cv2.erode( - inv_mask, - np.ones( - (int(2 * self.upscale_factor), int(2 * self.upscale_factor)), - np.uint8, - ), - ) - pasted_face = inv_mask_erosion[:, :, None] * inv_restored - total_face_area = np.sum(inv_mask_erosion) # // 3 - # add border - if draw_box: - h, w = face_size - mask_border = np.ones((h, w, 3), dtype=np.float32) - border = int(1400 / np.sqrt(total_face_area)) - mask_border[border : h - border, border : w - border, :] = 0 - inv_mask_border = cv2.warpAffine( - mask_border, inverse_affine, (w_up, h_up) - ) - inv_mask_borders.append(inv_mask_border) - # compute the fusion edge based on the area of face - w_edge = int(total_face_area**0.5) // 20 - erosion_radius = w_edge * 2 - inv_mask_center = cv2.erode( - inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8) - ) - blur_size = w_edge * 2 - inv_soft_mask = cv2.GaussianBlur( - inv_mask_center, (blur_size + 1, blur_size + 1), 0 - ) - if len(upsample_img.shape) == 2: # upsample_img is gray image - upsample_img = upsample_img[:, :, None] - inv_soft_mask = inv_soft_mask[:, :, None] - - # parse mask - if self.use_parse: - # inference - face_input = cv2.resize( - restored_face, (512, 512), interpolation=cv2.INTER_LINEAR - ) - face_input = img2tensor( - 
face_input.astype("float32") / 255.0, bgr2rgb=True, float32=True - ) - normalize(face_input, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) - face_input = torch.unsqueeze(face_input, 0).to(self.device) - with torch.no_grad(): - out = self.face_parse(face_input)[0] - out = out.argmax(dim=1).squeeze().cpu().numpy() - - parse_mask = np.zeros(out.shape) - MASK_COLORMAP = [ - 0, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 255, - 0, - 255, - 0, - 0, - 0, - ] - for idx, color in enumerate(MASK_COLORMAP): - parse_mask[out == idx] = color - # blur the mask - parse_mask = cv2.GaussianBlur(parse_mask, (101, 101), 11) - parse_mask = cv2.GaussianBlur(parse_mask, (101, 101), 11) - # remove the black borders - thres = 10 - parse_mask[:thres, :] = 0 - parse_mask[-thres:, :] = 0 - parse_mask[:, :thres] = 0 - parse_mask[:, -thres:] = 0 - parse_mask = parse_mask / 255.0 - - parse_mask = cv2.resize(parse_mask, face_size) - parse_mask = cv2.warpAffine( - parse_mask, inverse_affine, (w_up, h_up), flags=3 - ) - inv_soft_parse_mask = parse_mask[:, :, None] - # pasted_face = inv_restored - fuse_mask = (inv_soft_parse_mask < inv_soft_mask).astype("int") - inv_soft_mask = inv_soft_parse_mask * fuse_mask + inv_soft_mask * ( - 1 - fuse_mask - ) - - if ( - len(upsample_img.shape) == 3 and upsample_img.shape[2] == 4 - ): # alpha channel - alpha = upsample_img[:, :, 3:] - upsample_img = ( - inv_soft_mask * pasted_face - + (1 - inv_soft_mask) * upsample_img[:, :, 0:3] - ) - upsample_img = np.concatenate((upsample_img, alpha), axis=2) - else: - upsample_img = ( - inv_soft_mask * pasted_face + (1 - inv_soft_mask) * upsample_img - ) - - if np.max(upsample_img) > 256: # 16-bit image - upsample_img = upsample_img.astype(np.uint16) - else: - upsample_img = upsample_img.astype(np.uint8) - - # draw bounding box - if draw_box: - # upsample_input_img = cv2.resize(input_img, (w_up, h_up)) - img_color = np.ones([*upsample_img.shape], dtype=np.float32) - img_color[:, :, 0] = 0 - img_color[:, :, 1] = 255 - img_color[:, :, 2] = 0 - for inv_mask_border in inv_mask_borders: - upsample_img = ( - inv_mask_border * img_color + (1 - inv_mask_border) * upsample_img - ) - # upsample_input_img = inv_mask_border * img_color + (1 - inv_mask_border) * upsample_input_img - - if save_path is not None: - path = os.path.splitext(save_path)[0] - save_path = f"{path}.{self.save_ext}" - imwrite(upsample_img, save_path) - return upsample_img - - def clean_all(self): - self.all_landmarks_5 = [] - self.restored_faces = [] - self.affine_matrices = [] - self.cropped_faces = [] - self.inverse_affine_matrices = [] - self.det_faces = [] - self.pad_input_imgs = [] +import os + +import cv2 +import numpy as np +import torch +from torchvision.transforms.functional import normalize + +from hordelib.nodes.facerestore_cf.facelib.detection import init_detection_model +from hordelib.nodes.facerestore_cf.facelib.parsing import init_parsing_model +from hordelib.nodes.facerestore_cf.facelib.utils.misc import img2tensor, imwrite + + +def get_largest_face(det_faces, h, w): + + def get_location(val, length): + if val < 0: + return 0 + elif val > length: + return length + else: + return val + + face_areas = [] + for det_face in det_faces: + left = get_location(det_face[0], w) + right = get_location(det_face[2], w) + top = get_location(det_face[1], h) + bottom = get_location(det_face[3], h) + face_area = (right - left) * (bottom - top) + face_areas.append(face_area) + largest_idx = face_areas.index(max(face_areas)) + return 
det_faces[largest_idx], largest_idx + + +def get_center_face(det_faces, h=0, w=0, center=None): + if center is not None: + center = np.array(center) + else: + center = np.array([w / 2, h / 2]) + center_dist = [] + for det_face in det_faces: + face_center = np.array([(det_face[0] + det_face[2]) / 2, (det_face[1] + det_face[3]) / 2]) + dist = np.linalg.norm(face_center - center) + center_dist.append(dist) + center_idx = center_dist.index(min(center_dist)) + return det_faces[center_idx], center_idx + + +class FaceRestoreHelper: + """Helper for the face restoration pipeline (base class).""" + + def __init__( + self, + upscale_factor, + face_size=512, + crop_ratio=(1, 1), + det_model="retinaface_resnet50", + save_ext="png", + template_3points=False, + pad_blur=False, + use_parse=False, + device=None, + ): + self.template_3points = template_3points # improve robustness + self.upscale_factor = upscale_factor + # the cropped face ratio based on the square face + self.crop_ratio = crop_ratio # (h, w) + assert self.crop_ratio[0] >= 1 and self.crop_ratio[1] >= 1, "crop ration only supports >=1" + self.face_size = (int(face_size * self.crop_ratio[1]), int(face_size * self.crop_ratio[0])) + + if self.template_3points: + self.face_template = np.array([[192, 240], [319, 240], [257, 371]]) + else: + # standard 5 landmarks for FFHQ faces with 512 x 512 + # facexlib + self.face_template = np.array( + [ + [192.98138, 239.94708], + [318.90277, 240.1936], + [256.63416, 314.01935], + [201.26117, 371.41043], + [313.08905, 371.15118], + ], + ) + + # dlib: left_eye: 36:41 right_eye: 42:47 nose: 30,32,33,34 left mouth corner: 48 right mouth corner: 54 + # self.face_template = np.array([[193.65928, 242.98541], [318.32558, 243.06108], [255.67984, 328.82894], + # [198.22603, 372.82502], [313.91018, 372.75659]]) + + self.face_template = self.face_template * (face_size / 512.0) + if self.crop_ratio[0] > 1: + self.face_template[:, 1] += face_size * (self.crop_ratio[0] - 1) / 2 + if self.crop_ratio[1] > 1: + self.face_template[:, 0] += face_size * (self.crop_ratio[1] - 1) / 2 + self.save_ext = save_ext + self.pad_blur = pad_blur + if self.pad_blur is True: + self.template_3points = False + + self.all_landmarks_5 = [] + self.det_faces = [] + self.affine_matrices = [] + self.inverse_affine_matrices = [] + self.cropped_faces = [] + self.restored_faces = [] + self.pad_input_imgs = [] + + if device is None: + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + else: + self.device = device + + # init face detection model + self.face_det = init_detection_model(det_model, half=False, device=self.device) + + # init face parsing model + self.use_parse = use_parse + self.face_parse = init_parsing_model(model_name="parsenet", device=self.device) + + def set_upscale_factor(self, upscale_factor): + self.upscale_factor = upscale_factor + + def read_image(self, img): + """img can be image path or cv2 loaded image.""" + # self.input_img is Numpy array, (h, w, c), BGR, uint8, [0, 255] + if isinstance(img, str): + img = cv2.imread(img) + + if np.max(img) > 256: # 16-bit image + img = img / 65535 * 255 + if len(img.shape) == 2: # gray image + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + elif img.shape[2] == 4: # BGRA image with alpha channel + img = img[:, :, 0:3] + + self.input_img = img + + if min(self.input_img.shape[:2]) < 512: + f = 512.0 / min(self.input_img.shape[:2]) + self.input_img = cv2.resize(self.input_img, (0, 0), fx=f, fy=f, interpolation=cv2.INTER_LINEAR) + + def get_face_landmarks_5( + self, + 
only_keep_largest=False, + only_center_face=False, + resize=None, + blur_ratio=0.01, + eye_dist_threshold=None, + ): + if resize is None: + scale = 1 + input_img = self.input_img + else: + h, w = self.input_img.shape[0:2] + scale = resize / min(h, w) + scale = max(1, scale) # always scale up + h, w = int(h * scale), int(w * scale) + interp = cv2.INTER_AREA if scale < 1 else cv2.INTER_LINEAR + input_img = cv2.resize(self.input_img, (w, h), interpolation=interp) + + with torch.no_grad(): + bboxes = self.face_det.detect_faces(input_img) + + if bboxes is None or bboxes.shape[0] == 0: + return 0 + else: + bboxes = bboxes / scale + + for bbox in bboxes: + # remove faces with too small eye distance: side faces or too small faces + eye_dist = np.linalg.norm([bbox[6] - bbox[8], bbox[7] - bbox[9]]) + if eye_dist_threshold is not None and (eye_dist < eye_dist_threshold): + continue + + if self.template_3points: + landmark = np.array([[bbox[i], bbox[i + 1]] for i in range(5, 11, 2)]) + else: + landmark = np.array([[bbox[i], bbox[i + 1]] for i in range(5, 15, 2)]) + self.all_landmarks_5.append(landmark) + self.det_faces.append(bbox[0:5]) + + if len(self.det_faces) == 0: + return 0 + if only_keep_largest: + h, w, _ = self.input_img.shape + self.det_faces, largest_idx = get_largest_face(self.det_faces, h, w) + self.all_landmarks_5 = [self.all_landmarks_5[largest_idx]] + elif only_center_face: + h, w, _ = self.input_img.shape + self.det_faces, center_idx = get_center_face(self.det_faces, h, w) + self.all_landmarks_5 = [self.all_landmarks_5[center_idx]] + + # pad blurry images + if self.pad_blur: + self.pad_input_imgs = [] + for landmarks in self.all_landmarks_5: + # get landmarks + eye_left = landmarks[0, :] + eye_right = landmarks[1, :] + eye_avg = (eye_left + eye_right) * 0.5 + mouth_avg = (landmarks[3, :] + landmarks[4, :]) * 0.5 + eye_to_eye = eye_right - eye_left + eye_to_mouth = mouth_avg - eye_avg + + # Get the oriented crop rectangle + # x: half width of the oriented crop rectangle + x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1] + # - np.flipud(eye_to_mouth) * [-1, 1]: rotate 90 clockwise + # norm with the hypotenuse: get the direction + x /= np.hypot(*x) # get the hypotenuse of a right triangle + rect_scale = 1.5 + x *= max(np.hypot(*eye_to_eye) * 2.0 * rect_scale, np.hypot(*eye_to_mouth) * 1.8 * rect_scale) + # y: half height of the oriented crop rectangle + y = np.flipud(x) * [-1, 1] + + # c: center + c = eye_avg + eye_to_mouth * 0.1 + # quad: (left_top, left_bottom, right_bottom, right_top) + quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y]) + # qsize: side length of the square + qsize = np.hypot(*x) * 2 + border = max(int(np.rint(qsize * 0.1)), 3) + + # get pad + # pad: (width_left, height_top, width_right, height_bottom) + pad = ( + int(np.floor(min(quad[:, 0]))), + int(np.floor(min(quad[:, 1]))), + int(np.ceil(max(quad[:, 0]))), + int(np.ceil(max(quad[:, 1]))), + ) + pad = [ + max(-pad[0] + border, 1), + max(-pad[1] + border, 1), + max(pad[2] - self.input_img.shape[0] + border, 1), + max(pad[3] - self.input_img.shape[1] + border, 1), + ] + + if max(pad) > 1: + # pad image + pad_img = np.pad(self.input_img, ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), "reflect") + # modify landmark coords + landmarks[:, 0] += pad[0] + landmarks[:, 1] += pad[1] + # blur pad images + h, w, _ = pad_img.shape + y, x, _ = np.ogrid[:h, :w, :1] + mask = np.maximum( + 1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w - 1 - x) / pad[2]), + 1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h - 
1 - y) / pad[3]), + ) + blur = int(qsize * blur_ratio) + if blur % 2 == 0: + blur += 1 + blur_img = cv2.boxFilter(pad_img, 0, ksize=(blur, blur)) + # blur_img = cv2.GaussianBlur(pad_img, (blur, blur), 0) + + pad_img = pad_img.astype("float32") + pad_img += (blur_img - pad_img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0) + pad_img += (np.median(pad_img, axis=(0, 1)) - pad_img) * np.clip(mask, 0.0, 1.0) + pad_img = np.clip(pad_img, 0, 255) # float32, [0, 255] + self.pad_input_imgs.append(pad_img) + else: + self.pad_input_imgs.append(np.copy(self.input_img)) + + return len(self.all_landmarks_5) + + def align_warp_face(self, save_cropped_path=None, border_mode="constant"): + """Align and warp faces with face template.""" + if self.pad_blur: + assert len(self.pad_input_imgs) == len( + self.all_landmarks_5, + ), f"Mismatched samples: {len(self.pad_input_imgs)} and {len(self.all_landmarks_5)}" + for idx, landmark in enumerate(self.all_landmarks_5): + # use 5 landmarks to get affine matrix + # use cv2.LMEDS method for the equivalence to skimage transform + # ref: https://blog.csdn.net/yichxi/article/details/115827338 + affine_matrix = cv2.estimateAffinePartial2D(landmark, self.face_template, method=cv2.LMEDS)[0] + self.affine_matrices.append(affine_matrix) + # warp and crop faces + if border_mode == "constant": + border_mode = cv2.BORDER_CONSTANT + elif border_mode == "reflect101": + border_mode = cv2.BORDER_REFLECT101 + elif border_mode == "reflect": + border_mode = cv2.BORDER_REFLECT + if self.pad_blur: + input_img = self.pad_input_imgs[idx] + else: + input_img = self.input_img + cropped_face = cv2.warpAffine( + input_img, + affine_matrix, + self.face_size, + borderMode=border_mode, + borderValue=(135, 133, 132), + ) # gray + self.cropped_faces.append(cropped_face) + # save the cropped face + if save_cropped_path is not None: + path = os.path.splitext(save_cropped_path)[0] + save_path = f"{path}_{idx:02d}.{self.save_ext}" + imwrite(cropped_face, save_path) + + def get_inverse_affine(self, save_inverse_affine_path=None): + """Get inverse affine matrix.""" + for idx, affine_matrix in enumerate(self.affine_matrices): + inverse_affine = cv2.invertAffineTransform(affine_matrix) + inverse_affine *= self.upscale_factor + self.inverse_affine_matrices.append(inverse_affine) + # save inverse affine matrices + if save_inverse_affine_path is not None: + path, _ = os.path.splitext(save_inverse_affine_path) + save_path = f"{path}_{idx:02d}.pth" + torch.save(inverse_affine, save_path) + + def add_restored_face(self, face): + self.restored_faces.append(face) + + def paste_faces_to_input_image(self, save_path=None, upsample_img=None, draw_box=False, face_upsampler=None): + h, w, _ = self.input_img.shape + h_up, w_up = int(h * self.upscale_factor), int(w * self.upscale_factor) + + if upsample_img is None: + # simply resize the background + # upsample_img = cv2.resize(self.input_img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4) + upsample_img = cv2.resize(self.input_img, (w_up, h_up), interpolation=cv2.INTER_LINEAR) + else: + upsample_img = cv2.resize(upsample_img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4) + + assert len(self.restored_faces) == len( + self.inverse_affine_matrices, + ), "length of restored_faces and affine_matrices are different." 
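End to end, FaceRestoreHelper is driven as a detect, crop, restore, paste round trip. A minimal sketch under the renamed import path; the detector and parser weights are fetched by the init_detection_model / init_parsing_model calls in __init__, "input.png" is a hypothetical image, and the identity assignment stands in for whatever face restorer the node wires in:

import numpy as np
import torch

from hordelib.nodes.facerestore_cf.facelib.utils.face_restoration_helper import FaceRestoreHelper

helper = FaceRestoreHelper(upscale_factor=1, face_size=512, use_parse=True, device=torch.device("cpu"))
helper.read_image("input.png")                    # path or BGR uint8 array
if helper.get_face_landmarks_5(eye_dist_threshold=5) > 0:
    helper.align_warp_face()                      # fills helper.cropped_faces with 512x512 crops
    for cropped_face in helper.cropped_faces:
        restored = cropped_face                   # placeholder: run the real restorer on the crop here
        helper.add_restored_face(restored.astype(np.uint8))
    helper.get_inverse_affine()
    result = helper.paste_faces_to_input_image()  # BGR uint8, faces blended back using the parse mask
helper.clean_all()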
+ + inv_mask_borders = [] + for restored_face, inverse_affine in zip(self.restored_faces, self.inverse_affine_matrices, strict=False): + if face_upsampler is not None: + restored_face = face_upsampler.enhance(restored_face, outscale=self.upscale_factor)[0] + inverse_affine /= self.upscale_factor + inverse_affine[:, 2] *= self.upscale_factor + face_size = (self.face_size[0] * self.upscale_factor, self.face_size[1] * self.upscale_factor) + else: + # Add an offset to inverse affine matrix, for more precise back alignment + if self.upscale_factor > 1: + extra_offset = 0.5 * self.upscale_factor + else: + extra_offset = 0 + inverse_affine[:, 2] += extra_offset + face_size = self.face_size + inv_restored = cv2.warpAffine(restored_face, inverse_affine, (w_up, h_up)) + + # if draw_box or not self.use_parse: # use square parse maps + # mask = np.ones(face_size, dtype=np.float32) + # inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up)) + # # remove the black borders + # inv_mask_erosion = cv2.erode( + # inv_mask, np.ones((int(2 * self.upscale_factor), int(2 * self.upscale_factor)), np.uint8)) + # pasted_face = inv_mask_erosion[:, :, None] * inv_restored + # total_face_area = np.sum(inv_mask_erosion) # // 3 + # # add border + # if draw_box: + # h, w = face_size + # mask_border = np.ones((h, w, 3), dtype=np.float32) + # border = int(1400/np.sqrt(total_face_area)) + # mask_border[border:h-border, border:w-border,:] = 0 + # inv_mask_border = cv2.warpAffine(mask_border, inverse_affine, (w_up, h_up)) + # inv_mask_borders.append(inv_mask_border) + # if not self.use_parse: + # # compute the fusion edge based on the area of face + # w_edge = int(total_face_area**0.5) // 20 + # erosion_radius = w_edge * 2 + # inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8)) + # blur_size = w_edge * 2 + # inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0) + # if len(upsample_img.shape) == 2: # upsample_img is gray image + # upsample_img = upsample_img[:, :, None] + # inv_soft_mask = inv_soft_mask[:, :, None] + + # always use square mask + mask = np.ones(face_size, dtype=np.float32) + inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up)) + # remove the black borders + inv_mask_erosion = cv2.erode( + inv_mask, + np.ones((int(2 * self.upscale_factor), int(2 * self.upscale_factor)), np.uint8), + ) + pasted_face = inv_mask_erosion[:, :, None] * inv_restored + total_face_area = np.sum(inv_mask_erosion) # // 3 + # add border + if draw_box: + h, w = face_size + mask_border = np.ones((h, w, 3), dtype=np.float32) + border = int(1400 / np.sqrt(total_face_area)) + mask_border[border : h - border, border : w - border, :] = 0 + inv_mask_border = cv2.warpAffine(mask_border, inverse_affine, (w_up, h_up)) + inv_mask_borders.append(inv_mask_border) + # compute the fusion edge based on the area of face + w_edge = int(total_face_area**0.5) // 20 + erosion_radius = w_edge * 2 + inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8)) + blur_size = w_edge * 2 + inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0) + if len(upsample_img.shape) == 2: # upsample_img is gray image + upsample_img = upsample_img[:, :, None] + inv_soft_mask = inv_soft_mask[:, :, None] + + # parse mask + if self.use_parse: + # inference + face_input = cv2.resize(restored_face, (512, 512), interpolation=cv2.INTER_LINEAR) + face_input = img2tensor(face_input.astype("float32") / 255.0, bgr2rgb=True, 
float32=True) + normalize(face_input, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) + face_input = torch.unsqueeze(face_input, 0).to(self.device) + with torch.no_grad(): + out = self.face_parse(face_input)[0] + out = out.argmax(dim=1).squeeze().cpu().numpy() + + parse_mask = np.zeros(out.shape) + MASK_COLORMAP = [0, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 0, 255, 0, 0, 0] + for idx, color in enumerate(MASK_COLORMAP): + parse_mask[out == idx] = color + # blur the mask + parse_mask = cv2.GaussianBlur(parse_mask, (101, 101), 11) + parse_mask = cv2.GaussianBlur(parse_mask, (101, 101), 11) + # remove the black borders + thres = 10 + parse_mask[:thres, :] = 0 + parse_mask[-thres:, :] = 0 + parse_mask[:, :thres] = 0 + parse_mask[:, -thres:] = 0 + parse_mask = parse_mask / 255.0 + + parse_mask = cv2.resize(parse_mask, face_size) + parse_mask = cv2.warpAffine(parse_mask, inverse_affine, (w_up, h_up), flags=3) + inv_soft_parse_mask = parse_mask[:, :, None] + # pasted_face = inv_restored + fuse_mask = (inv_soft_parse_mask < inv_soft_mask).astype("int") + inv_soft_mask = inv_soft_parse_mask * fuse_mask + inv_soft_mask * (1 - fuse_mask) + + if len(upsample_img.shape) == 3 and upsample_img.shape[2] == 4: # alpha channel + alpha = upsample_img[:, :, 3:] + upsample_img = inv_soft_mask * pasted_face + (1 - inv_soft_mask) * upsample_img[:, :, 0:3] + upsample_img = np.concatenate((upsample_img, alpha), axis=2) + else: + upsample_img = inv_soft_mask * pasted_face + (1 - inv_soft_mask) * upsample_img + + if np.max(upsample_img) > 256: # 16-bit image + upsample_img = upsample_img.astype(np.uint16) + else: + upsample_img = upsample_img.astype(np.uint8) + + # draw bounding box + if draw_box: + # upsample_input_img = cv2.resize(input_img, (w_up, h_up)) + img_color = np.ones([*upsample_img.shape], dtype=np.float32) + img_color[:, :, 0] = 0 + img_color[:, :, 1] = 255 + img_color[:, :, 2] = 0 + for inv_mask_border in inv_mask_borders: + upsample_img = inv_mask_border * img_color + (1 - inv_mask_border) * upsample_img + # upsample_input_img = inv_mask_border * img_color + (1 - inv_mask_border) * upsample_input_img + + if save_path is not None: + path = os.path.splitext(save_path)[0] + save_path = f"{path}.{self.save_ext}" + imwrite(upsample_img, save_path) + return upsample_img + + def clean_all(self): + self.all_landmarks_5 = [] + self.restored_faces = [] + self.affine_matrices = [] + self.cropped_faces = [] + self.inverse_affine_matrices = [] + self.det_faces = [] + self.pad_input_imgs = [] diff --git a/hordelib/nodes/facerestore/facelib/utils/face_utils.py b/hordelib/nodes/facerestore_cf/facelib/utils/face_utils.py similarity index 93% rename from hordelib/nodes/facerestore/facelib/utils/face_utils.py rename to hordelib/nodes/facerestore_cf/facelib/utils/face_utils.py index 5ee39570..5769a0f7 100644 --- a/hordelib/nodes/facerestore/facelib/utils/face_utils.py +++ b/hordelib/nodes/facerestore_cf/facelib/utils/face_utils.py @@ -1,283 +1,275 @@ -import cv2 -import numpy as np -import torch - - -def compute_increased_bbox(bbox, increase_area, preserve_aspect=True): - left, top, right, bot = bbox - width = right - left - height = bot - top - - if preserve_aspect: - width_increase = max(increase_area, ((1 + 2 * increase_area) * height - width) / (2 * width)) - height_increase = max(increase_area, ((1 + 2 * increase_area) * width - height) / (2 * height)) - else: - width_increase = height_increase = increase_area - left = int(left - width_increase * width) - top = int(top - height_increase * 
height) - right = int(right + width_increase * width) - bot = int(bot + height_increase * height) - return (left, top, right, bot) - - -def get_valid_bboxes(bboxes, h, w): - left = max(bboxes[0], 0) - top = max(bboxes[1], 0) - right = min(bboxes[2], w) - bottom = min(bboxes[3], h) - return (left, top, right, bottom) - - -def align_crop_face_landmarks( - img, - landmarks, - output_size, - transform_size=None, - enable_padding=True, - return_inverse_affine=False, - shrink_ratio=(1, 1), -): - """Align and crop face with landmarks. - - The output_size and transform_size are based on width. The height is - adjusted based on shrink_ratio_h/shring_ration_w. - - Modified from: - https://github.com/NVlabs/ffhq-dataset/blob/master/download_ffhq.py - - Args: - img (Numpy array): Input image. - landmarks (Numpy array): 5 or 68 or 98 landmarks. - output_size (int): Output face size. - transform_size (ing): Transform size. Usually the four time of - output_size. - enable_padding (float): Default: True. - shrink_ratio (float | tuple[float] | list[float]): Shring the whole - face for height and width (crop larger area). Default: (1, 1). - - Returns: - (Numpy array): Cropped face. - """ - lm_type = "retinaface_5" # Options: dlib_5, retinaface_5 - - if isinstance(shrink_ratio, (float, int)): - shrink_ratio = (shrink_ratio, shrink_ratio) - if transform_size is None: - transform_size = output_size * 4 - - # Parse landmarks - lm = np.array(landmarks) - if lm.shape[0] == 5 and lm_type == "retinaface_5": - eye_left = lm[0] - eye_right = lm[1] - mouth_avg = (lm[3] + lm[4]) * 0.5 - elif lm.shape[0] == 5 and lm_type == "dlib_5": - lm_eye_left = lm[2:4] - lm_eye_right = lm[0:2] - eye_left = np.mean(lm_eye_left, axis=0) - eye_right = np.mean(lm_eye_right, axis=0) - mouth_avg = lm[4] - elif lm.shape[0] == 68: - lm_eye_left = lm[36:42] - lm_eye_right = lm[42:48] - eye_left = np.mean(lm_eye_left, axis=0) - eye_right = np.mean(lm_eye_right, axis=0) - mouth_avg = (lm[48] + lm[54]) * 0.5 - elif lm.shape[0] == 98: - lm_eye_left = lm[60:68] - lm_eye_right = lm[68:76] - eye_left = np.mean(lm_eye_left, axis=0) - eye_right = np.mean(lm_eye_right, axis=0) - mouth_avg = (lm[76] + lm[82]) * 0.5 - - eye_avg = (eye_left + eye_right) * 0.5 - eye_to_eye = eye_right - eye_left - eye_to_mouth = mouth_avg - eye_avg - - # Get the oriented crop rectangle - # x: half width of the oriented crop rectangle - x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1] - # - np.flipud(eye_to_mouth) * [-1, 1]: rotate 90 clockwise - # norm with the hypotenuse: get the direction - x /= np.hypot(*x) # get the hypotenuse of a right triangle - rect_scale = 1 # TODO: you can edit it to get larger rect - x *= max( - np.hypot(*eye_to_eye) * 2.0 * rect_scale, - np.hypot(*eye_to_mouth) * 1.8 * rect_scale, - ) - # y: half height of the oriented crop rectangle - y = np.flipud(x) * [-1, 1] - - x *= shrink_ratio[1] # width - y *= shrink_ratio[0] # height - - # c: center - c = eye_avg + eye_to_mouth * 0.1 - # quad: (left_top, left_bottom, right_bottom, right_top) - quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y]) - # qsize: side length of the square - qsize = np.hypot(*x) * 2 - - quad_ori = np.copy(quad) - # Shrink, for large face - # TODO: do we really need shrink - shrink = int(np.floor(qsize / output_size * 0.5)) - if shrink > 1: - h, w = img.shape[0:2] - rsize = (int(np.rint(float(w) / shrink)), int(np.rint(float(h) / shrink))) - img = cv2.resize(img, rsize, interpolation=cv2.INTER_AREA) - quad /= shrink - qsize /= shrink - - # Crop - h, w = 
img.shape[0:2] - border = max(int(np.rint(qsize * 0.1)), 3) - crop = ( - int(np.floor(min(quad[:, 0]))), - int(np.floor(min(quad[:, 1]))), - int(np.ceil(max(quad[:, 0]))), - int(np.ceil(max(quad[:, 1]))), - ) - crop = ( - max(crop[0] - border, 0), - max(crop[1] - border, 0), - min(crop[2] + border, w), - min(crop[3] + border, h), - ) - if crop[2] - crop[0] < w or crop[3] - crop[1] < h: - img = img[crop[1] : crop[3], crop[0] : crop[2], :] - quad -= crop[0:2] - - # Pad - # pad: (width_left, height_top, width_right, height_bottom) - h, w = img.shape[0:2] - pad = ( - int(np.floor(min(quad[:, 0]))), - int(np.floor(min(quad[:, 1]))), - int(np.ceil(max(quad[:, 0]))), - int(np.ceil(max(quad[:, 1]))), - ) - pad = ( - max(-pad[0] + border, 0), - max(-pad[1] + border, 0), - max(pad[2] - w + border, 0), - max(pad[3] - h + border, 0), - ) - if enable_padding and max(pad) > border - 4: - pad = np.maximum(pad, int(np.rint(qsize * 0.3))) - img = np.pad(img, ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), "reflect") - h, w = img.shape[0:2] - y, x, _ = np.ogrid[:h, :w, :1] - mask = np.maximum( - 1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w - 1 - x) / pad[2]), - 1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h - 1 - y) / pad[3]), - ) - blur = int(qsize * 0.02) - if blur % 2 == 0: - blur += 1 - blur_img = cv2.boxFilter(img, 0, ksize=(blur, blur)) - - img = img.astype("float32") - img += (blur_img - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0) - img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0) - img = np.clip(img, 0, 255) # float32, [0, 255] - quad += pad[:2] - - # Transform use cv2 - h_ratio = shrink_ratio[0] / shrink_ratio[1] - dst_h, dst_w = int(transform_size * h_ratio), transform_size - template = np.array([[0, 0], [0, dst_h], [dst_w, dst_h], [dst_w, 0]]) - # use cv2.LMEDS method for the equivalence to skimage transform - # ref: https://blog.csdn.net/yichxi/article/details/115827338 - affine_matrix = cv2.estimateAffinePartial2D(quad, template, method=cv2.LMEDS)[0] - cropped_face = cv2.warpAffine( - img, - affine_matrix, - (dst_w, dst_h), - borderMode=cv2.BORDER_CONSTANT, - borderValue=(135, 133, 132), - ) # gray - - if output_size < transform_size: - cropped_face = cv2.resize( - cropped_face, - (output_size, int(output_size * h_ratio)), - interpolation=cv2.INTER_LINEAR, - ) - - if return_inverse_affine: - dst_h, dst_w = int(output_size * h_ratio), output_size - template = np.array([[0, 0], [0, dst_h], [dst_w, dst_h], [dst_w, 0]]) - # use cv2.LMEDS method for the equivalence to skimage transform - # ref: https://blog.csdn.net/yichxi/article/details/115827338 - affine_matrix = cv2.estimateAffinePartial2D( - quad_ori, - np.array([[0, 0], [0, output_size], [dst_w, dst_h], [dst_w, 0]]), - method=cv2.LMEDS, - )[0] - inverse_affine = cv2.invertAffineTransform(affine_matrix) - else: - inverse_affine = None - return cropped_face, inverse_affine - - -def paste_face_back(img, face, inverse_affine): - h, w = img.shape[0:2] - face_h, face_w = face.shape[0:2] - inv_restored = cv2.warpAffine(face, inverse_affine, (w, h)) - mask = np.ones((face_h, face_w, 3), dtype=np.float32) - inv_mask = cv2.warpAffine(mask, inverse_affine, (w, h)) - # remove the black borders - inv_mask_erosion = cv2.erode(inv_mask, np.ones((2, 2), np.uint8)) - inv_restored_remove_border = inv_mask_erosion * inv_restored - total_face_area = np.sum(inv_mask_erosion) // 3 - # compute the fusion edge based on the area of face - w_edge = int(total_face_area**0.5) // 20 - erosion_radius = w_edge * 2 - inv_mask_center = 
cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8)) - blur_size = w_edge * 2 - inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0) - img = inv_soft_mask * inv_restored_remove_border + (1 - inv_soft_mask) * img - # float32, [0, 255] - return img - - -if __name__ == "__main__": - import os - - from hordelib.nodes.facerestore.facelib.detection import init_detection_model - from hordelib.nodes.facerestore.facelib.utils.face_restoration_helper import get_largest_face - - img_path = "/home/wxt/datasets/ffhq/ffhq_wild/00009.png" - img_name = os.path.splitext(os.path.basename(img_path))[0] - - # initialize model - det_net = init_detection_model("retinaface_resnet50", half=False) - img_ori = cv2.imread(img_path) - h, w = img_ori.shape[0:2] - # if larger than 800, scale it - scale = max(h / 800, w / 800) - if scale > 1: - img = cv2.resize(img_ori, (int(w / scale), int(h / scale)), interpolation=cv2.INTER_LINEAR) - - with torch.no_grad(): - bboxes = det_net.detect_faces(img, 0.97) - if scale > 1: - bboxes *= scale # the score is incorrect - bboxes = get_largest_face(bboxes, h, w)[0] - - landmarks = np.array([[bboxes[i], bboxes[i + 1]] for i in range(5, 15, 2)]) - - cropped_face, inverse_affine = align_crop_face_landmarks( - img_ori, - landmarks, - output_size=512, - transform_size=None, - enable_padding=True, - return_inverse_affine=True, - shrink_ratio=(1, 1), - ) - - cv2.imwrite(f"tmp/{img_name}_cropeed_face.png", cropped_face) - img = paste_face_back(img_ori, cropped_face, inverse_affine) - cv2.imwrite(f"tmp/{img_name}_back.png", img) +import cv2 +import numpy as np +import torch + + +def compute_increased_bbox(bbox, increase_area, preserve_aspect=True): + left, top, right, bot = bbox + width = right - left + height = bot - top + + if preserve_aspect: + width_increase = max(increase_area, ((1 + 2 * increase_area) * height - width) / (2 * width)) + height_increase = max(increase_area, ((1 + 2 * increase_area) * width - height) / (2 * height)) + else: + width_increase = height_increase = increase_area + left = int(left - width_increase * width) + top = int(top - height_increase * height) + right = int(right + width_increase * width) + bot = int(bot + height_increase * height) + return (left, top, right, bot) + + +def get_valid_bboxes(bboxes, h, w): + left = max(bboxes[0], 0) + top = max(bboxes[1], 0) + right = min(bboxes[2], w) + bottom = min(bboxes[3], h) + return (left, top, right, bottom) + + +def align_crop_face_landmarks( + img, + landmarks, + output_size, + transform_size=None, + enable_padding=True, + return_inverse_affine=False, + shrink_ratio=(1, 1), +): + """Align and crop face with landmarks. + + The output_size and transform_size are based on width. The height is + adjusted based on shrink_ratio_h/shring_ration_w. + + Modified from: + https://github.com/NVlabs/ffhq-dataset/blob/master/download_ffhq.py + + Args: + img (Numpy array): Input image. + landmarks (Numpy array): 5 or 68 or 98 landmarks. + output_size (int): Output face size. + transform_size (ing): Transform size. Usually the four time of + output_size. + enable_padding (float): Default: True. + shrink_ratio (float | tuple[float] | list[float]): Shring the whole + face for height and width (crop larger area). Default: (1, 1). + + Returns: + (Numpy array): Cropped face. 
+ """ + lm_type = "retinaface_5" # Options: dlib_5, retinaface_5 + + if isinstance(shrink_ratio, (float, int)): + shrink_ratio = (shrink_ratio, shrink_ratio) + if transform_size is None: + transform_size = output_size * 4 + + # Parse landmarks + lm = np.array(landmarks) + if lm.shape[0] == 5 and lm_type == "retinaface_5": + eye_left = lm[0] + eye_right = lm[1] + mouth_avg = (lm[3] + lm[4]) * 0.5 + elif lm.shape[0] == 5 and lm_type == "dlib_5": + lm_eye_left = lm[2:4] + lm_eye_right = lm[0:2] + eye_left = np.mean(lm_eye_left, axis=0) + eye_right = np.mean(lm_eye_right, axis=0) + mouth_avg = lm[4] + elif lm.shape[0] == 68: + lm_eye_left = lm[36:42] + lm_eye_right = lm[42:48] + eye_left = np.mean(lm_eye_left, axis=0) + eye_right = np.mean(lm_eye_right, axis=0) + mouth_avg = (lm[48] + lm[54]) * 0.5 + elif lm.shape[0] == 98: + lm_eye_left = lm[60:68] + lm_eye_right = lm[68:76] + eye_left = np.mean(lm_eye_left, axis=0) + eye_right = np.mean(lm_eye_right, axis=0) + mouth_avg = (lm[76] + lm[82]) * 0.5 + + eye_avg = (eye_left + eye_right) * 0.5 + eye_to_eye = eye_right - eye_left + eye_to_mouth = mouth_avg - eye_avg + + # Get the oriented crop rectangle + # x: half width of the oriented crop rectangle + x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1] + # - np.flipud(eye_to_mouth) * [-1, 1]: rotate 90 clockwise + # norm with the hypotenuse: get the direction + x /= np.hypot(*x) # get the hypotenuse of a right triangle + rect_scale = 1 # TODO: you can edit it to get larger rect + x *= max(np.hypot(*eye_to_eye) * 2.0 * rect_scale, np.hypot(*eye_to_mouth) * 1.8 * rect_scale) + # y: half height of the oriented crop rectangle + y = np.flipud(x) * [-1, 1] + + x *= shrink_ratio[1] # width + y *= shrink_ratio[0] # height + + # c: center + c = eye_avg + eye_to_mouth * 0.1 + # quad: (left_top, left_bottom, right_bottom, right_top) + quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y]) + # qsize: side length of the square + qsize = np.hypot(*x) * 2 + + quad_ori = np.copy(quad) + # Shrink, for large face + # TODO: do we really need shrink + shrink = int(np.floor(qsize / output_size * 0.5)) + if shrink > 1: + h, w = img.shape[0:2] + rsize = (int(np.rint(float(w) / shrink)), int(np.rint(float(h) / shrink))) + img = cv2.resize(img, rsize, interpolation=cv2.INTER_AREA) + quad /= shrink + qsize /= shrink + + # Crop + h, w = img.shape[0:2] + border = max(int(np.rint(qsize * 0.1)), 3) + crop = ( + int(np.floor(min(quad[:, 0]))), + int(np.floor(min(quad[:, 1]))), + int(np.ceil(max(quad[:, 0]))), + int(np.ceil(max(quad[:, 1]))), + ) + crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, w), min(crop[3] + border, h)) + if crop[2] - crop[0] < w or crop[3] - crop[1] < h: + img = img[crop[1] : crop[3], crop[0] : crop[2], :] + quad -= crop[0:2] + + # Pad + # pad: (width_left, height_top, width_right, height_bottom) + h, w = img.shape[0:2] + pad = ( + int(np.floor(min(quad[:, 0]))), + int(np.floor(min(quad[:, 1]))), + int(np.ceil(max(quad[:, 0]))), + int(np.ceil(max(quad[:, 1]))), + ) + pad = ( + max(-pad[0] + border, 0), + max(-pad[1] + border, 0), + max(pad[2] - w + border, 0), + max(pad[3] - h + border, 0), + ) + if enable_padding and max(pad) > border - 4: + pad = np.maximum(pad, int(np.rint(qsize * 0.3))) + img = np.pad(img, ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), "reflect") + h, w = img.shape[0:2] + y, x, _ = np.ogrid[:h, :w, :1] + mask = np.maximum( + 1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w - 1 - x) / pad[2]), + 1.0 - np.minimum(np.float32(y) / pad[1], 
np.float32(h - 1 - y) / pad[3]), + ) + blur = int(qsize * 0.02) + if blur % 2 == 0: + blur += 1 + blur_img = cv2.boxFilter(img, 0, ksize=(blur, blur)) + + img = img.astype("float32") + img += (blur_img - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0) + img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0) + img = np.clip(img, 0, 255) # float32, [0, 255] + quad += pad[:2] + + # Transform use cv2 + h_ratio = shrink_ratio[0] / shrink_ratio[1] + dst_h, dst_w = int(transform_size * h_ratio), transform_size + template = np.array([[0, 0], [0, dst_h], [dst_w, dst_h], [dst_w, 0]]) + # use cv2.LMEDS method for the equivalence to skimage transform + # ref: https://blog.csdn.net/yichxi/article/details/115827338 + affine_matrix = cv2.estimateAffinePartial2D(quad, template, method=cv2.LMEDS)[0] + cropped_face = cv2.warpAffine( + img, + affine_matrix, + (dst_w, dst_h), + borderMode=cv2.BORDER_CONSTANT, + borderValue=(135, 133, 132), + ) # gray + + if output_size < transform_size: + cropped_face = cv2.resize( + cropped_face, + (output_size, int(output_size * h_ratio)), + interpolation=cv2.INTER_LINEAR, + ) + + if return_inverse_affine: + dst_h, dst_w = int(output_size * h_ratio), output_size + template = np.array([[0, 0], [0, dst_h], [dst_w, dst_h], [dst_w, 0]]) + # use cv2.LMEDS method for the equivalence to skimage transform + # ref: https://blog.csdn.net/yichxi/article/details/115827338 + affine_matrix = cv2.estimateAffinePartial2D( + quad_ori, + np.array([[0, 0], [0, output_size], [dst_w, dst_h], [dst_w, 0]]), + method=cv2.LMEDS, + )[0] + inverse_affine = cv2.invertAffineTransform(affine_matrix) + else: + inverse_affine = None + return cropped_face, inverse_affine + + +def paste_face_back(img, face, inverse_affine): + h, w = img.shape[0:2] + face_h, face_w = face.shape[0:2] + inv_restored = cv2.warpAffine(face, inverse_affine, (w, h)) + mask = np.ones((face_h, face_w, 3), dtype=np.float32) + inv_mask = cv2.warpAffine(mask, inverse_affine, (w, h)) + # remove the black borders + inv_mask_erosion = cv2.erode(inv_mask, np.ones((2, 2), np.uint8)) + inv_restored_remove_border = inv_mask_erosion * inv_restored + total_face_area = np.sum(inv_mask_erosion) // 3 + # compute the fusion edge based on the area of face + w_edge = int(total_face_area**0.5) // 20 + erosion_radius = w_edge * 2 + inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8)) + blur_size = w_edge * 2 + inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0) + img = inv_soft_mask * inv_restored_remove_border + (1 - inv_soft_mask) * img + # float32, [0, 255] + return img + + +if __name__ == "__main__": + import os + + from hordelib.nodes.facerestore_cf.facelib.detection import init_detection_model + from hordelib.nodes.facerestore_cf.facelib.utils.face_restoration_helper import get_largest_face + + img_path = "/home/wxt/datasets/ffhq/ffhq_wild/00009.png" + img_name = os.path.splitext(os.path.basename(img_path))[0] + + # initialize model + det_net = init_detection_model("retinaface_resnet50", half=False) + img_ori = cv2.imread(img_path) + h, w = img_ori.shape[0:2] + # if larger than 800, scale it + scale = max(h / 800, w / 800) + if scale > 1: + img = cv2.resize(img_ori, (int(w / scale), int(h / scale)), interpolation=cv2.INTER_LINEAR) + + with torch.no_grad(): + bboxes = det_net.detect_faces(img, 0.97) + if scale > 1: + bboxes *= scale # the score is incorrect + bboxes = get_largest_face(bboxes, h, w)[0] + + landmarks = np.array([[bboxes[i], bboxes[i + 1]] for i 
in range(5, 15, 2)]) + + cropped_face, inverse_affine = align_crop_face_landmarks( + img_ori, + landmarks, + output_size=512, + transform_size=None, + enable_padding=True, + return_inverse_affine=True, + shrink_ratio=(1, 1), + ) + + cv2.imwrite(f"tmp/{img_name}_cropeed_face.png", cropped_face) + img = paste_face_back(img_ori, cropped_face, inverse_affine) + cv2.imwrite(f"tmp/{img_name}_back.png", img) diff --git a/hordelib/nodes/facerestore/facelib/utils/misc.py b/hordelib/nodes/facerestore_cf/facelib/utils/misc.py similarity index 70% rename from hordelib/nodes/facerestore/facelib/utils/misc.py rename to hordelib/nodes/facerestore_cf/facelib/utils/misc.py index 68b7fe9d..0eee9f24 100644 --- a/hordelib/nodes/facerestore/facelib/utils/misc.py +++ b/hordelib/nodes/facerestore_cf/facelib/utils/misc.py @@ -1,143 +1,132 @@ -import cv2 -import os -import os.path as osp -import torch -from torch.hub import download_url_to_file, get_dir -from urllib.parse import urlparse -# from basicsr.utils.download_util import download_file_from_google_drive -#import gdown - - -ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) - - -def download_pretrained_models(file_ids, save_path_root): - os.makedirs(save_path_root, exist_ok=True) - - for file_name, file_id in file_ids.items(): - file_url = 'https://drive.google.com/uc?id='+file_id - save_path = osp.abspath(osp.join(save_path_root, file_name)) - if osp.exists(save_path): - user_response = input(f'{file_name} already exist. Do you want to cover it? Y/N\n') - if user_response.lower() == 'y': - print(f'Covering {file_name} to {save_path}') - print("skipping gdown in facelib/utils/misc.py "+file_url) - #gdown.download(file_url, save_path, quiet=False) - # download_file_from_google_drive(file_id, save_path) - elif user_response.lower() == 'n': - print(f'Skipping {file_name}') - else: - raise ValueError('Wrong input. Only accepts Y/N.') - else: - print(f'Downloading {file_name} to {save_path}') - print("skipping gdown in facelib/utils/misc.py "+file_url) - #gdown.download(file_url, save_path, quiet=False) - # download_file_from_google_drive(file_id, save_path) - - -def imwrite(img, file_path, params=None, auto_mkdir=True): - """Write image to file. - - Args: - img (ndarray): Image array to be written. - file_path (str): Image file path. - params (None or list): Same as opencv's :func:`imwrite` interface. - auto_mkdir (bool): If the parent folder of `file_path` does not exist, - whether to create it automatically. - - Returns: - bool: Successful or not. - """ - if auto_mkdir: - dir_name = os.path.abspath(os.path.dirname(file_path)) - os.makedirs(dir_name, exist_ok=True) - return cv2.imwrite(file_path, img, params) - - -def img2tensor(imgs, bgr2rgb=True, float32=True): - """Numpy array to tensor. - - Args: - imgs (list[ndarray] | ndarray): Input images. - bgr2rgb (bool): Whether to change bgr to rgb. - float32 (bool): Whether to change to float32. - - Returns: - list[tensor] | tensor: Tensor images. If returned results only have - one element, just return tensor. 
- """ - - def _totensor(img, bgr2rgb, float32): - if img.shape[2] == 3 and bgr2rgb: - if img.dtype == 'float64': - img = img.astype('float32') - img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) - img = torch.from_numpy(img.transpose(2, 0, 1)) - if float32: - img = img.float() - return img - - if isinstance(imgs, list): - return [_totensor(img, bgr2rgb, float32) for img in imgs] - else: - return _totensor(imgs, bgr2rgb, float32) - - -def load_file_from_url(url, model_dir=None, progress=True, file_name=None): - """Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py - """ - if model_dir is None: - hub_dir = get_dir() - model_dir = os.path.join(hub_dir, 'checkpoints') - - os.makedirs(os.path.join(ROOT_DIR, model_dir), exist_ok=True) - - parts = urlparse(url) - filename = os.path.basename(parts.path) - if file_name is not None: - filename = file_name - cached_file = os.path.abspath(os.path.join(ROOT_DIR, model_dir, filename)) - if not os.path.exists(cached_file): - print(f'Downloading: "{url}" to {cached_file}\n') - download_url_to_file(url, cached_file, hash_prefix=None, progress=progress) - return cached_file - - -def scandir(dir_path, suffix=None, recursive=False, full_path=False): - """Scan a directory to find the interested files. - Args: - dir_path (str): Path of the directory. - suffix (str | tuple(str), optional): File suffix that we are - interested in. Default: None. - recursive (bool, optional): If set to True, recursively scan the - directory. Default: False. - full_path (bool, optional): If set to True, include the dir_path. - Default: False. - Returns: - A generator for all the interested files with relative paths. - """ - - if (suffix is not None) and not isinstance(suffix, (str, tuple)): - raise TypeError('"suffix" must be a string or tuple of strings') - - root = dir_path - - def _scandir(dir_path, suffix, recursive): - for entry in os.scandir(dir_path): - if not entry.name.startswith('.') and entry.is_file(): - if full_path: - return_path = entry.path - else: - return_path = osp.relpath(entry.path, root) - - if suffix is None: - yield return_path - elif return_path.endswith(suffix): - yield return_path - else: - if recursive: - yield from _scandir(entry.path, suffix=suffix, recursive=recursive) - else: - continue - - return _scandir(dir_path, suffix=suffix, recursive=recursive) +import os +import os.path as osp +from urllib.parse import urlparse + +import cv2 +import folder_paths +import torch +from torch.hub import download_url_to_file, get_dir +from hordelib.shared_model_manager import SharedModelManager + +# from hordelib.nodes.facerestore.basicsr.utils.download_util import download_file_from_google_drive +# import gdown + + +ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + + +def download_pretrained_models(file_ids, save_path_root): + os.makedirs(save_path_root, exist_ok=True) + + for file_name, file_id in file_ids.items(): + file_url = "https://drive.google.com/uc?id=" + file_id + save_path = osp.abspath(osp.join(save_path_root, file_name)) + if osp.exists(save_path): + user_response = input(f"{file_name} already exist. Do you want to cover it? 
Y/N\n") + if user_response.lower() == "y": + print(f"Covering {file_name} to {save_path}") + print("skipping gdown in facelib/utils/misc.py " + file_url) + # gdown.download(file_url, save_path, quiet=False) + # download_file_from_google_drive(file_id, save_path) + elif user_response.lower() == "n": + print(f"Skipping {file_name}") + else: + raise ValueError("Wrong input. Only accepts Y/N.") + else: + print(f"Downloading {file_name} to {save_path}") + print("skipping gdown in facelib/utils/misc.py " + file_url) + # gdown.download(file_url, save_path, quiet=False) + # download_file_from_google_drive(file_id, save_path) + + +def imwrite(img, file_path, params=None, auto_mkdir=True): + """Write image to file. + + Args: + img (ndarray): Image array to be written. + file_path (str): Image file path. + params (None or list): Same as opencv's :func:`imwrite` interface. + auto_mkdir (bool): If the parent folder of `file_path` does not exist, + whether to create it automatically. + + Returns: + bool: Successful or not. + """ + if auto_mkdir: + dir_name = os.path.abspath(os.path.dirname(file_path)) + os.makedirs(dir_name, exist_ok=True) + return cv2.imwrite(file_path, img, params) + + +def img2tensor(imgs, bgr2rgb=True, float32=True): + """Numpy array to tensor. + + Args: + imgs (list[ndarray] | ndarray): Input images. + bgr2rgb (bool): Whether to change bgr to rgb. + float32 (bool): Whether to change to float32. + + Returns: + list[tensor] | tensor: Tensor images. If returned results only have + one element, just return tensor. + """ + + def _totensor(img, bgr2rgb, float32): + if img.shape[2] == 3 and bgr2rgb: + if img.dtype == "float64": + img = img.astype("float32") + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = torch.from_numpy(img.transpose(2, 0, 1)) + if float32: + img = img.float() + return img + + if isinstance(imgs, list): + return [_totensor(img, bgr2rgb, float32) for img in imgs] + else: + return _totensor(imgs, bgr2rgb, float32) + + +def load_file_from_url(url, model_dir=None, progress=True, file_name=None): + """Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py""" + return str(SharedModelManager.manager.gfpgan.model_folder_path / file_name) + + +def scandir(dir_path, suffix=None, recursive=False, full_path=False): + """Scan a directory to find the interested files. + Args: + dir_path (str): Path of the directory. + suffix (str | tuple(str), optional): File suffix that we are + interested in. Default: None. + recursive (bool, optional): If set to True, recursively scan the + directory. Default: False. + full_path (bool, optional): If set to True, include the dir_path. + Default: False. + Returns: + A generator for all the interested files with relative paths. 
+ """ + + if (suffix is not None) and not isinstance(suffix, (str, tuple)): + raise TypeError('"suffix" must be a string or tuple of strings') + + root = dir_path + + def _scandir(dir_path, suffix, recursive): + for entry in os.scandir(dir_path): + if not entry.name.startswith(".") and entry.is_file(): + if full_path: + return_path = entry.path + else: + return_path = osp.relpath(entry.path, root) + + if suffix is None: + yield return_path + elif return_path.endswith(suffix): + yield return_path + else: + if recursive: + yield from _scandir(entry.path, suffix=suffix, recursive=recursive) + else: + continue + + return _scandir(dir_path, suffix=suffix, recursive=recursive) diff --git a/hordelib/nodes/facerestore_cf/r_chainner/README.md b/hordelib/nodes/facerestore_cf/r_chainner/README.md new file mode 100644 index 00000000..6b254d3f --- /dev/null +++ b/hordelib/nodes/facerestore_cf/r_chainner/README.md @@ -0,0 +1,3 @@ +Clean implementation for GFPGAN loading copied from [this commit](https://github.com/Gourieff/comfyui-reactor-node/commit/a7ae66912f80e8ccd97bb83bf83ab8187b077287#diff-b668993a9f6df352129e883337a4f2c96b31ab61afd82c4ae948d40864962c12) to solve [this issue](https://github.com/mav-rik/facerestore_cf) + +Solution Discovered in this issue: https://github.com/comfyanonymous/ComfyUI/issues/3594 diff --git a/hordelib/nodes/facerestore_cf/r_chainner/__init__.py b/hordelib/nodes/facerestore_cf/r_chainner/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/hordelib/nodes/facerestore_cf/r_chainner/gfpganv1_clean_arch.py b/hordelib/nodes/facerestore_cf/r_chainner/gfpganv1_clean_arch.py new file mode 100644 index 00000000..7f2f0e75 --- /dev/null +++ b/hordelib/nodes/facerestore_cf/r_chainner/gfpganv1_clean_arch.py @@ -0,0 +1,370 @@ +# pylint: skip-file +# type: ignore +import math +import random + +import torch +from torch import nn +from torch.nn import functional as F + +from hordelib.nodes.facerestore_cf.r_chainner.stylegan2_clean_arch import StyleGAN2GeneratorClean + + +class StyleGAN2GeneratorCSFT(StyleGAN2GeneratorClean): + """StyleGAN2 Generator with SFT modulation (Spatial Feature Transform). + It is the clean version without custom compiled CUDA extensions used in StyleGAN2. + Args: + out_size (int): The spatial size of outputs. + num_style_feat (int): Channel number of style features. Default: 512. + num_mlp (int): Layer number of MLP style layers. Default: 8. + channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2. + narrow (float): The narrow ratio for channels. Default: 1. + sft_half (bool): Whether to apply SFT on half of the input channels. Default: False. + """ + + def __init__( + self, + out_size, + num_style_feat=512, + num_mlp=8, + channel_multiplier=2, + narrow=1, + sft_half=False, + ): + super(StyleGAN2GeneratorCSFT, self).__init__( + out_size, + num_style_feat=num_style_feat, + num_mlp=num_mlp, + channel_multiplier=channel_multiplier, + narrow=narrow, + ) + self.sft_half = sft_half + + def forward( + self, + styles, + conditions, + input_is_latent=False, + noise=None, + randomize_noise=True, + truncation=1, + truncation_latent=None, + inject_index=None, + return_latents=False, + ): + """Forward function for StyleGAN2GeneratorCSFT. + Args: + styles (list[Tensor]): Sample codes of styles. + conditions (list[Tensor]): SFT conditions to generators. + input_is_latent (bool): Whether input is latent style. Default: False. + noise (Tensor | None): Input noise or None. Default: None. 
+ randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True. + truncation (float): The truncation ratio. Default: 1. + truncation_latent (Tensor | None): The truncation latent tensor. Default: None. + inject_index (int | None): The injection index for mixing noise. Default: None. + return_latents (bool): Whether to return style latents. Default: False. + """ + # style codes -> latents with Style MLP layer + if not input_is_latent: + styles = [self.style_mlp(s) for s in styles] + # noises + if noise is None: + if randomize_noise: + noise = [None] * self.num_layers # for each style conv layer + else: # use the stored noise + noise = [ + getattr(self.noises, f"noise{i}") for i in range(self.num_layers) + ] + # style truncation + if truncation < 1: + style_truncation = [] + for style in styles: + style_truncation.append( + truncation_latent + truncation * (style - truncation_latent) + ) + styles = style_truncation + # get style latents with injection + if len(styles) == 1: + inject_index = self.num_latent + + if styles[0].ndim < 3: + # repeat latent code for all the layers + latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) + else: # used for encoder with different latent code for each layer + latent = styles[0] + elif len(styles) == 2: # mixing noises + if inject_index is None: + inject_index = random.randint(1, self.num_latent - 1) + latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1) + latent2 = ( + styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1) + ) + latent = torch.cat([latent1, latent2], 1) + + # main generation + out = self.constant_input(latent.shape[0]) + out = self.style_conv1(out, latent[:, 0], noise=noise[0]) + skip = self.to_rgb1(out, latent[:, 1]) + + i = 1 + for conv1, conv2, noise1, noise2, to_rgb in zip( + self.style_convs[::2], + self.style_convs[1::2], + noise[1::2], + noise[2::2], + self.to_rgbs, + ): + out = conv1(out, latent[:, i], noise=noise1) + + # the conditions may have fewer levels + if i < len(conditions): + # SFT part to combine the conditions + if self.sft_half: # only apply SFT to half of the channels + out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1) + out_sft = out_sft * conditions[i - 1] + conditions[i] + out = torch.cat([out_same, out_sft], dim=1) + else: # apply SFT to all the channels + out = out * conditions[i - 1] + conditions[i] + + out = conv2(out, latent[:, i + 1], noise=noise2) + skip = to_rgb(out, latent[:, i + 2], skip) # feature back to the rgb space + i += 2 + + image = skip + + if return_latents: + return image, latent + else: + return image, None + + +class ResBlock(nn.Module): + """Residual block with bilinear upsampling/downsampling. + Args: + in_channels (int): Channel number of the input. + out_channels (int): Channel number of the output. + mode (str): Upsampling/downsampling mode. Options: down | up. Default: down. 
+ """ + + def __init__(self, in_channels, out_channels, mode="down"): + super(ResBlock, self).__init__() + + self.conv1 = nn.Conv2d(in_channels, in_channels, 3, 1, 1) + self.conv2 = nn.Conv2d(in_channels, out_channels, 3, 1, 1) + self.skip = nn.Conv2d(in_channels, out_channels, 1, bias=False) + if mode == "down": + self.scale_factor = 0.5 + elif mode == "up": + self.scale_factor = 2 + + def forward(self, x): + out = F.leaky_relu_(self.conv1(x), negative_slope=0.2) + # upsample/downsample + out = F.interpolate( + out, scale_factor=self.scale_factor, mode="bilinear", align_corners=False + ) + out = F.leaky_relu_(self.conv2(out), negative_slope=0.2) + # skip + x = F.interpolate( + x, scale_factor=self.scale_factor, mode="bilinear", align_corners=False + ) + skip = self.skip(x) + out = out + skip + return out + + +class GFPGANv1Clean(nn.Module): + """The GFPGAN architecture: Unet + StyleGAN2 decoder with SFT. + It is the clean version without custom compiled CUDA extensions used in StyleGAN2. + Ref: GFP-GAN: Towards Real-World Blind Face Restoration with Generative Facial Prior. + Args: + out_size (int): The spatial size of outputs. + num_style_feat (int): Channel number of style features. Default: 512. + channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2. + decoder_load_path (str): The path to the pre-trained decoder model (usually, the StyleGAN2). Default: None. + fix_decoder (bool): Whether to fix the decoder. Default: True. + num_mlp (int): Layer number of MLP style layers. Default: 8. + input_is_latent (bool): Whether input is latent style. Default: False. + different_w (bool): Whether to use different latent w for different layers. Default: False. + narrow (float): The narrow ratio for channels. Default: 1. + sft_half (bool): Whether to apply SFT on half of the input channels. Default: False. 
+ """ + + def __init__( + self, + state_dict, + ): + super(GFPGANv1Clean, self).__init__() + + out_size = 512 + num_style_feat = 512 + channel_multiplier = 2 + decoder_load_path = None + fix_decoder = False + num_mlp = 8 + input_is_latent = True + different_w = True + narrow = 1 + sft_half = True + + self.model_arch = "GFPGAN" + self.sub_type = "Face SR" + self.scale = 8 + self.in_nc = 3 + self.out_nc = 3 + self.state = state_dict + + self.supports_fp16 = False + self.supports_bf16 = True + self.min_size_restriction = 512 + + self.input_is_latent = input_is_latent + self.different_w = different_w + self.num_style_feat = num_style_feat + + unet_narrow = narrow * 0.5 # by default, use a half of input channels + channels = { + "4": int(512 * unet_narrow), + "8": int(512 * unet_narrow), + "16": int(512 * unet_narrow), + "32": int(512 * unet_narrow), + "64": int(256 * channel_multiplier * unet_narrow), + "128": int(128 * channel_multiplier * unet_narrow), + "256": int(64 * channel_multiplier * unet_narrow), + "512": int(32 * channel_multiplier * unet_narrow), + "1024": int(16 * channel_multiplier * unet_narrow), + } + + self.log_size = int(math.log(out_size, 2)) + first_out_size = 2 ** (int(math.log(out_size, 2))) + + self.conv_body_first = nn.Conv2d(3, channels[f"{first_out_size}"], 1) + + # downsample + in_channels = channels[f"{first_out_size}"] + self.conv_body_down = nn.ModuleList() + for i in range(self.log_size, 2, -1): + out_channels = channels[f"{2**(i - 1)}"] + self.conv_body_down.append(ResBlock(in_channels, out_channels, mode="down")) + in_channels = out_channels + + self.final_conv = nn.Conv2d(in_channels, channels["4"], 3, 1, 1) + + # upsample + in_channels = channels["4"] + self.conv_body_up = nn.ModuleList() + for i in range(3, self.log_size + 1): + out_channels = channels[f"{2**i}"] + self.conv_body_up.append(ResBlock(in_channels, out_channels, mode="up")) + in_channels = out_channels + + # to RGB + self.toRGB = nn.ModuleList() + for i in range(3, self.log_size + 1): + self.toRGB.append(nn.Conv2d(channels[f"{2**i}"], 3, 1)) + + if different_w: + linear_out_channel = (int(math.log(out_size, 2)) * 2 - 2) * num_style_feat + else: + linear_out_channel = num_style_feat + + self.final_linear = nn.Linear(channels["4"] * 4 * 4, linear_out_channel) + + # the decoder: stylegan2 generator with SFT modulations + self.stylegan_decoder = StyleGAN2GeneratorCSFT( + out_size=out_size, + num_style_feat=num_style_feat, + num_mlp=num_mlp, + channel_multiplier=channel_multiplier, + narrow=narrow, + sft_half=sft_half, + ) + + # load pre-trained stylegan2 model if necessary + if decoder_load_path: + self.stylegan_decoder.load_state_dict( + torch.load( + decoder_load_path, map_location=lambda storage, loc: storage + )["params_ema"] + ) + # fix decoder without updating params + if fix_decoder: + for _, param in self.stylegan_decoder.named_parameters(): + param.requires_grad = False + + # for SFT modulations (scale and shift) + self.condition_scale = nn.ModuleList() + self.condition_shift = nn.ModuleList() + for i in range(3, self.log_size + 1): + out_channels = channels[f"{2**i}"] + if sft_half: + sft_out_channels = out_channels + else: + sft_out_channels = out_channels * 2 + self.condition_scale.append( + nn.Sequential( + nn.Conv2d(out_channels, out_channels, 3, 1, 1), + nn.LeakyReLU(0.2, True), + nn.Conv2d(out_channels, sft_out_channels, 3, 1, 1), + ) + ) + self.condition_shift.append( + nn.Sequential( + nn.Conv2d(out_channels, out_channels, 3, 1, 1), + nn.LeakyReLU(0.2, True), + 
nn.Conv2d(out_channels, sft_out_channels, 3, 1, 1), + ) + ) + self.load_state_dict(state_dict) + + def forward( + self, x, return_latents=False, return_rgb=True, randomize_noise=True, **kwargs + ): + """Forward function for GFPGANv1Clean. + Args: + x (Tensor): Input images. + return_latents (bool): Whether to return style latents. Default: False. + return_rgb (bool): Whether return intermediate rgb images. Default: True. + randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True. + """ + conditions = [] + unet_skips = [] + out_rgbs = [] + + # encoder + feat = F.leaky_relu_(self.conv_body_first(x), negative_slope=0.2) + for i in range(self.log_size - 2): + feat = self.conv_body_down[i](feat) + unet_skips.insert(0, feat) + feat = F.leaky_relu_(self.final_conv(feat), negative_slope=0.2) + + # style code + style_code = self.final_linear(feat.view(feat.size(0), -1)) + if self.different_w: + style_code = style_code.view(style_code.size(0), -1, self.num_style_feat) + + # decode + for i in range(self.log_size - 2): + # add unet skip + feat = feat + unet_skips[i] + # ResUpLayer + feat = self.conv_body_up[i](feat) + # generate scale and shift for SFT layers + scale = self.condition_scale[i](feat) + conditions.append(scale.clone()) + shift = self.condition_shift[i](feat) + conditions.append(shift.clone()) + # generate rgb images + if return_rgb: + out_rgbs.append(self.toRGB[i](feat)) + + # decoder + image, _ = self.stylegan_decoder( + [style_code], + conditions, + return_latents=return_latents, + input_is_latent=self.input_is_latent, + randomize_noise=randomize_noise, + ) + + return image, out_rgbs diff --git a/hordelib/nodes/facerestore_cf/r_chainner/model_loading.py b/hordelib/nodes/facerestore_cf/r_chainner/model_loading.py new file mode 100644 index 00000000..598e605c --- /dev/null +++ b/hordelib/nodes/facerestore_cf/r_chainner/model_loading.py @@ -0,0 +1,29 @@ + +from hordelib.nodes.facerestore_cf.r_chainner.gfpganv1_clean_arch import GFPGANv1Clean +from hordelib.nodes.facerestore_cf.r_chainner.types import PyTorchModel + + +class UnsupportedModel(Exception): + pass + + +def load_state_dict(state_dict) -> PyTorchModel: + + state_dict_keys = list(state_dict.keys()) + + if "params_ema" in state_dict_keys: + state_dict = state_dict["params_ema"] + elif "params-ema" in state_dict_keys: + state_dict = state_dict["params-ema"] + elif "params" in state_dict_keys: + state_dict = state_dict["params"] + + state_dict_keys = list(state_dict.keys()) + + # GFPGAN + if ( + "toRGB.0.weight" in state_dict_keys + and "stylegan_decoder.style_mlp.1.weight" in state_dict_keys + ): + model = GFPGANv1Clean(state_dict) + return model diff --git a/hordelib/nodes/facerestore_cf/r_chainner/stylegan2_clean_arch.py b/hordelib/nodes/facerestore_cf/r_chainner/stylegan2_clean_arch.py new file mode 100644 index 00000000..c48de9af --- /dev/null +++ b/hordelib/nodes/facerestore_cf/r_chainner/stylegan2_clean_arch.py @@ -0,0 +1,453 @@ +# pylint: skip-file +# type: ignore +import math + +import torch +from torch import nn +from torch.nn import functional as F +from torch.nn import init +from torch.nn.modules.batchnorm import _BatchNorm + + +@torch.no_grad() +def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs): + """Initialize network weights. + Args: + module_list (list[nn.Module] | nn.Module): Modules to be initialized. + scale (float): Scale initialized weights, especially for residual + blocks. Default: 1. + bias_fill (float): The value to fill bias. 
Default: 0 + kwargs (dict): Other arguments for initialization function. + """ + if not isinstance(module_list, list): + module_list = [module_list] + for module in module_list: + for m in module.modules(): + if isinstance(m, nn.Conv2d): + init.kaiming_normal_(m.weight, **kwargs) + m.weight.data *= scale + if m.bias is not None: + m.bias.data.fill_(bias_fill) + elif isinstance(m, nn.Linear): + init.kaiming_normal_(m.weight, **kwargs) + m.weight.data *= scale + if m.bias is not None: + m.bias.data.fill_(bias_fill) + elif isinstance(m, _BatchNorm): + init.constant_(m.weight, 1) + if m.bias is not None: + m.bias.data.fill_(bias_fill) + + +class NormStyleCode(nn.Module): + def forward(self, x): + """Normalize the style codes. + Args: + x (Tensor): Style codes with shape (b, c). + Returns: + Tensor: Normalized tensor. + """ + return x * torch.rsqrt(torch.mean(x**2, dim=1, keepdim=True) + 1e-8) + + +class ModulatedConv2d(nn.Module): + """Modulated Conv2d used in StyleGAN2. + There is no bias in ModulatedConv2d. + Args: + in_channels (int): Channel number of the input. + out_channels (int): Channel number of the output. + kernel_size (int): Size of the convolving kernel. + num_style_feat (int): Channel number of style features. + demodulate (bool): Whether to demodulate in the conv layer. Default: True. + sample_mode (str | None): Indicating 'upsample', 'downsample' or None. Default: None. + eps (float): A value added to the denominator for numerical stability. Default: 1e-8. + """ + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + num_style_feat, + demodulate=True, + sample_mode=None, + eps=1e-8, + ): + super(ModulatedConv2d, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.demodulate = demodulate + self.sample_mode = sample_mode + self.eps = eps + + # modulation inside each modulated conv + self.modulation = nn.Linear(num_style_feat, in_channels, bias=True) + # initialization + default_init_weights( + self.modulation, + scale=1, + bias_fill=1, + a=0, + mode="fan_in", + nonlinearity="linear", + ) + + self.weight = nn.Parameter( + torch.randn(1, out_channels, in_channels, kernel_size, kernel_size) + / math.sqrt(in_channels * kernel_size**2) + ) + self.padding = kernel_size // 2 + + def forward(self, x, style): + """Forward function. + Args: + x (Tensor): Tensor with shape (b, c, h, w). + style (Tensor): Tensor with shape (b, num_style_feat). + Returns: + Tensor: Modulated tensor after convolution. 
+ """ + b, c, h, w = x.shape # c = c_in + # weight modulation + style = self.modulation(style).view(b, 1, c, 1, 1) + # self.weight: (1, c_out, c_in, k, k); style: (b, 1, c, 1, 1) + weight = self.weight * style # (b, c_out, c_in, k, k) + + if self.demodulate: + demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps) + weight = weight * demod.view(b, self.out_channels, 1, 1, 1) + + weight = weight.view( + b * self.out_channels, c, self.kernel_size, self.kernel_size + ) + + # upsample or downsample if necessary + if self.sample_mode == "upsample": + x = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=False) + elif self.sample_mode == "downsample": + x = F.interpolate(x, scale_factor=0.5, mode="bilinear", align_corners=False) + + b, c, h, w = x.shape + x = x.view(1, b * c, h, w) + # weight: (b*c_out, c_in, k, k), groups=b + out = F.conv2d(x, weight, padding=self.padding, groups=b) + out = out.view(b, self.out_channels, *out.shape[2:4]) + + return out + + def __repr__(self): + return ( + f"{self.__class__.__name__}(in_channels={self.in_channels}, out_channels={self.out_channels}, " + f"kernel_size={self.kernel_size}, demodulate={self.demodulate}, sample_mode={self.sample_mode})" + ) + + +class StyleConv(nn.Module): + """Style conv used in StyleGAN2. + Args: + in_channels (int): Channel number of the input. + out_channels (int): Channel number of the output. + kernel_size (int): Size of the convolving kernel. + num_style_feat (int): Channel number of style features. + demodulate (bool): Whether demodulate in the conv layer. Default: True. + sample_mode (str | None): Indicating 'upsample', 'downsample' or None. Default: None. + """ + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + num_style_feat, + demodulate=True, + sample_mode=None, + ): + super(StyleConv, self).__init__() + self.modulated_conv = ModulatedConv2d( + in_channels, + out_channels, + kernel_size, + num_style_feat, + demodulate=demodulate, + sample_mode=sample_mode, + ) + self.weight = nn.Parameter(torch.zeros(1)) # for noise injection + self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1)) + self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + def forward(self, x, style, noise=None): + # modulate + out = self.modulated_conv(x, style) * 2**0.5 # for conversion + # noise injection + if noise is None: + b, _, h, w = out.shape + noise = out.new_empty(b, 1, h, w).normal_() + out = out + self.weight * noise + # add bias + out = out + self.bias + # activation + out = self.activate(out) + return out + + +class ToRGB(nn.Module): + """To RGB (image space) from features. + Args: + in_channels (int): Channel number of input. + num_style_feat (int): Channel number of style features. + upsample (bool): Whether to upsample. Default: True. + """ + + def __init__(self, in_channels, num_style_feat, upsample=True): + super(ToRGB, self).__init__() + self.upsample = upsample + self.modulated_conv = ModulatedConv2d( + in_channels, + 3, + kernel_size=1, + num_style_feat=num_style_feat, + demodulate=False, + sample_mode=None, + ) + self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1)) + + def forward(self, x, style, skip=None): + """Forward function. + Args: + x (Tensor): Feature tensor with shape (b, c, h, w). + style (Tensor): Tensor with shape (b, num_style_feat). + skip (Tensor): Base/skip tensor. Default: None. + Returns: + Tensor: RGB images. 
+ """ + out = self.modulated_conv(x, style) + out = out + self.bias + if skip is not None: + if self.upsample: + skip = F.interpolate( + skip, scale_factor=2, mode="bilinear", align_corners=False + ) + out = out + skip + return out + + +class ConstantInput(nn.Module): + """Constant input. + Args: + num_channel (int): Channel number of constant input. + size (int): Spatial size of constant input. + """ + + def __init__(self, num_channel, size): + super(ConstantInput, self).__init__() + self.weight = nn.Parameter(torch.randn(1, num_channel, size, size)) + + def forward(self, batch): + out = self.weight.repeat(batch, 1, 1, 1) + return out + + +class StyleGAN2GeneratorClean(nn.Module): + """Clean version of StyleGAN2 Generator. + Args: + out_size (int): The spatial size of outputs. + num_style_feat (int): Channel number of style features. Default: 512. + num_mlp (int): Layer number of MLP style layers. Default: 8. + channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2. + narrow (float): Narrow ratio for channels. Default: 1.0. + """ + + def __init__( + self, out_size, num_style_feat=512, num_mlp=8, channel_multiplier=2, narrow=1 + ): + super(StyleGAN2GeneratorClean, self).__init__() + # Style MLP layers + self.num_style_feat = num_style_feat + style_mlp_layers = [NormStyleCode()] + for i in range(num_mlp): + style_mlp_layers.extend( + [ + nn.Linear(num_style_feat, num_style_feat, bias=True), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ] + ) + self.style_mlp = nn.Sequential(*style_mlp_layers) + # initialization + default_init_weights( + self.style_mlp, + scale=1, + bias_fill=0, + a=0.2, + mode="fan_in", + nonlinearity="leaky_relu", + ) + + # channel list + channels = { + "4": int(512 * narrow), + "8": int(512 * narrow), + "16": int(512 * narrow), + "32": int(512 * narrow), + "64": int(256 * channel_multiplier * narrow), + "128": int(128 * channel_multiplier * narrow), + "256": int(64 * channel_multiplier * narrow), + "512": int(32 * channel_multiplier * narrow), + "1024": int(16 * channel_multiplier * narrow), + } + self.channels = channels + + self.constant_input = ConstantInput(channels["4"], size=4) + self.style_conv1 = StyleConv( + channels["4"], + channels["4"], + kernel_size=3, + num_style_feat=num_style_feat, + demodulate=True, + sample_mode=None, + ) + self.to_rgb1 = ToRGB(channels["4"], num_style_feat, upsample=False) + + self.log_size = int(math.log(out_size, 2)) + self.num_layers = (self.log_size - 2) * 2 + 1 + self.num_latent = self.log_size * 2 - 2 + + self.style_convs = nn.ModuleList() + self.to_rgbs = nn.ModuleList() + self.noises = nn.Module() + + in_channels = channels["4"] + # noise + for layer_idx in range(self.num_layers): + resolution = 2 ** ((layer_idx + 5) // 2) + shape = [1, 1, resolution, resolution] + self.noises.register_buffer(f"noise{layer_idx}", torch.randn(*shape)) + # style convs and to_rgbs + for i in range(3, self.log_size + 1): + out_channels = channels[f"{2**i}"] + self.style_convs.append( + StyleConv( + in_channels, + out_channels, + kernel_size=3, + num_style_feat=num_style_feat, + demodulate=True, + sample_mode="upsample", + ) + ) + self.style_convs.append( + StyleConv( + out_channels, + out_channels, + kernel_size=3, + num_style_feat=num_style_feat, + demodulate=True, + sample_mode=None, + ) + ) + self.to_rgbs.append(ToRGB(out_channels, num_style_feat, upsample=True)) + in_channels = out_channels + + def make_noise(self): + """Make noise for noise injection.""" + device = self.constant_input.weight.device + 
noises = [torch.randn(1, 1, 4, 4, device=device)] + + for i in range(3, self.log_size + 1): + for _ in range(2): + noises.append(torch.randn(1, 1, 2**i, 2**i, device=device)) + + return noises + + def get_latent(self, x): + return self.style_mlp(x) + + def mean_latent(self, num_latent): + latent_in = torch.randn( + num_latent, self.num_style_feat, device=self.constant_input.weight.device + ) + latent = self.style_mlp(latent_in).mean(0, keepdim=True) + return latent + + def forward( + self, + styles, + input_is_latent=False, + noise=None, + randomize_noise=True, + truncation=1, + truncation_latent=None, + inject_index=None, + return_latents=False, + ): + """Forward function for StyleGAN2GeneratorClean. + Args: + styles (list[Tensor]): Sample codes of styles. + input_is_latent (bool): Whether input is latent style. Default: False. + noise (Tensor | None): Input noise or None. Default: None. + randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True. + truncation (float): The truncation ratio. Default: 1. + truncation_latent (Tensor | None): The truncation latent tensor. Default: None. + inject_index (int | None): The injection index for mixing noise. Default: None. + return_latents (bool): Whether to return style latents. Default: False. + """ + # style codes -> latents with Style MLP layer + if not input_is_latent: + styles = [self.style_mlp(s) for s in styles] + # noises + if noise is None: + if randomize_noise: + noise = [None] * self.num_layers # for each style conv layer + else: # use the stored noise + noise = [ + getattr(self.noises, f"noise{i}") for i in range(self.num_layers) + ] + # style truncation + if truncation < 1: + style_truncation = [] + for style in styles: + style_truncation.append( + truncation_latent + truncation * (style - truncation_latent) + ) + styles = style_truncation + # get style latents with injection + if len(styles) == 1: + inject_index = self.num_latent + + if styles[0].ndim < 3: + # repeat latent code for all the layers + latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) + else: # used for encoder with different latent code for each layer + latent = styles[0] + elif len(styles) == 2: # mixing noises + if inject_index is None: + inject_index = random.randint(1, self.num_latent - 1) + latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1) + latent2 = ( + styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1) + ) + latent = torch.cat([latent1, latent2], 1) + + # main generation + out = self.constant_input(latent.shape[0]) + out = self.style_conv1(out, latent[:, 0], noise=noise[0]) + skip = self.to_rgb1(out, latent[:, 1]) + + i = 1 + for conv1, conv2, noise1, noise2, to_rgb in zip( + self.style_convs[::2], + self.style_convs[1::2], + noise[1::2], + noise[2::2], + self.to_rgbs, + ): + out = conv1(out, latent[:, i], noise=noise1) + out = conv2(out, latent[:, i + 1], noise=noise2) + skip = to_rgb(out, latent[:, i + 2], skip) # feature back to the rgb space + i += 2 + + image = skip + + if return_latents: + return image, latent + else: + return image, None diff --git a/hordelib/nodes/facerestore_cf/r_chainner/types.py b/hordelib/nodes/facerestore_cf/r_chainner/types.py new file mode 100644 index 00000000..20e39f68 --- /dev/null +++ b/hordelib/nodes/facerestore_cf/r_chainner/types.py @@ -0,0 +1,19 @@ + +from typing import Union + +from hordelib.nodes.facerestore_cf.r_chainner.gfpganv1_clean_arch import GFPGANv1Clean + + +PyTorchFaceModels = (GFPGANv1Clean,) +PyTorchFaceModel = Union[GFPGANv1Clean] + + +def 
is_pytorch_face_model(model: object): + return isinstance(model, PyTorchFaceModels) + +PyTorchModels = (*PyTorchFaceModels, ) +PyTorchModel = Union[PyTorchFaceModel] + + +def is_pytorch_model(model: object): + return isinstance(model, PyTorchModels) diff --git a/hordelib/nodes/node_model_loader.py b/hordelib/nodes/node_model_loader.py index 4ca17900..0105d019 100644 --- a/hordelib/nodes/node_model_loader.py +++ b/hordelib/nodes/node_model_loader.py @@ -9,8 +9,8 @@ import torch from loguru import logger -from hordelib.shared_model_manager import SharedModelManager from hordelib.comfy_horde import log_free_ram +from hordelib.shared_model_manager import SharedModelManager # Don't let the name fool you, this class is trying to load all the files that will be necessary diff --git a/hordelib/pipeline_designs/pipeline_image_facefix.json b/hordelib/pipeline_designs/pipeline_image_facefix.json index 19b15fbe..42cc8630 100644 --- a/hordelib/pipeline_designs/pipeline_image_facefix.json +++ b/hordelib/pipeline_designs/pipeline_image_facefix.json @@ -1,6 +1,6 @@ { - "last_node_id": 8, - "last_link_id": 8, + "last_node_id": 11, + "last_link_id": 14, "nodes": [ { "id": 6, @@ -9,10 +9,10 @@ 771, 331 ], - "size": [ - 427.0001220703125, - 416.33331298828125 - ], + "size": { + "0": 427.0001220703125, + "1": 416.33331298828125 + }, "flags": {}, "order": 3, "mode": 0, @@ -20,7 +20,7 @@ { "name": "images", "type": "IMAGE", - "link": 7 + "link": 14 } ], "title": "output_image", @@ -30,11 +30,11 @@ ] }, { - "id": 4, - "type": "UpscaleModelLoader", + "id": 10, + "type": "FaceRestoreModelLoader", "pos": [ - 24, - 73 + -5, + 145 ], "size": { "0": 315, @@ -45,133 +45,144 @@ "mode": 0, "outputs": [ { - "name": "UPSCALE_MODEL", - "type": "UPSCALE_MODEL", + "name": "FACERESTORE_MODEL", + "type": "FACERESTORE_MODEL", "links": [ - 6 + 12 ], + "shape": 3, "slot_index": 0 } ], "title": "model_loader", "properties": { - "Node name for S&R": "UpscaleModelLoader" + "Node name for S&R": "FaceRestoreModelLoader" }, "widgets_values": [ "CodeFormers.pth" ] }, { - "id": 8, - "type": "FaceRestoreWithModel", - "title": "face_restore_with_model", + "id": 1, + "type": "LoadImage", "pos": [ - 390, - 194 + 30, + 325 + ], + "size": [ + 315, + 314 ], - "size": { - "0": 315, - "1": 78 - }, "flags": {}, - "order": 2, + "order": 1, "mode": 0, - "inputs": [ - { - "name": "upscale_model", - "type": "UPSCALE_MODEL", - "link": 6 - }, - { - "name": "image", - "type": "IMAGE", - "link": 8 - } - ], "outputs": [ { "name": "IMAGE", "type": "IMAGE", "links": [ - 7 + 13 ], "slot_index": 0 + }, + { + "name": "MASK", + "type": "MASK", + "links": null } ], + "title": "image_loader", "properties": { - "Node name for S&R": "FaceRestoreWithModel" + "Node name for S&R": "LoadImage" }, "widgets_values": [ - "retinaface_resnet50" + "test_facefix.png", + "image" ] }, { - "id": 1, - "type": "LoadImage", + "id": 11, + "type": "FaceRestoreCFWithModel", "pos": [ - 30, - 325 + 397, + 166 ], "size": { "0": 315, "1": 102 }, "flags": {}, - "order": 1, + "order": 2, "mode": 0, + "inputs": [ + { + "name": "facerestore_model", + "type": "FACERESTORE_MODEL", + "link": 12 + }, + { + "name": "image", + "type": "IMAGE", + "link": 13 + } + ], "outputs": [ { "name": "IMAGE", "type": "IMAGE", "links": [ - 8 + 14 ], + "shape": 3, "slot_index": 0 - }, - { - "name": "MASK", - "type": "MASK", - "links": null } ], - "title": "image_loader", + "title": "face_restore_with_model", "properties": { - "Node name for S&R": "LoadImage" + "Node name for S&R": "FaceRestoreCFWithModel" }, 
"widgets_values": [ - "test_facefix.png", - "image" + "retinaface_resnet50", + 0.5 ] } ], "links": [ [ - 6, - 4, + 12, + 10, 0, - 8, + 11, 0, - "UPSCALE_MODEL" + "FACERESTORE_MODEL" ], [ - 7, - 8, - 0, - 6, + 13, + 1, 0, + 11, + 1, "IMAGE" ], [ - 8, - 1, + 14, + 11, + 0, + 6, 0, - 8, - 1, "IMAGE" ] ], "groups": [], "config": {}, - "extra": {}, + "extra": { + "ds": { + "scale": 1, + "offset": [ + 383.39140502393127, + 89.69757412917136 + ] + } + }, "version": 0.4 } diff --git a/hordelib/pipelines/pipeline_image_facefix.json b/hordelib/pipelines/pipeline_image_facefix.json index a00f479a..b06dcfae 100644 --- a/hordelib/pipelines/pipeline_image_facefix.json +++ b/hordelib/pipelines/pipeline_image_facefix.json @@ -9,20 +9,11 @@ "title": "image_loader" } }, - "4": { - "inputs": { - "model_name": "CodeFormers.pth" - }, - "class_type": "UpscaleModelLoader", - "_meta": { - "title": "model_loader" - } - }, "6": { "inputs": { "filename_prefix": "ComfyUI", "images": [ - "8", + "11", 0 ] }, @@ -31,11 +22,21 @@ "title": "output_image" } }, - "8": { + "10": { + "inputs": { + "model_name": "CodeFormers.pth" + }, + "class_type": "FaceRestoreModelLoader", + "_meta": { + "title": "model_loader" + } + }, + "11": { "inputs": { "facedetection": "retinaface_resnet50", - "upscale_model": [ - "4", + "codeformer_fidelity": 0.5, + "facerestore_model": [ + "10", 0 ], "image": [ @@ -43,9 +44,9 @@ 0 ] }, + "class_type": "FaceRestoreCFWithModel", "_meta": { "title": "face_restore_with_model" - }, - "class_type": "FaceRestoreWithModel" + } } } diff --git a/mypy.ini b/mypy.ini index 10a57f6c..fa0778a7 100644 --- a/mypy.ini +++ b/mypy.ini @@ -1,11 +1,15 @@ [mypy] -exclude = (build|dist|ComfyUI|comfy_controlnet_preprocessors|facerestore|comfy_horde\.py|examples|diffusers) +exclude = (build|dist|ComfyUI|comfy_controlnet_preprocessors|facerestore_cf|comfy_horde\.py|examples|diffusers) [mypy-hordelib.nodes.comfy_controlnet_preprocessors.*] ignore_errors = True ignore_missing_imports = True +[mypy-hordelib.nodes.facerestore_cf.*] +ignore_errors = True +ignore_missing_imports = True + [mypy-hordelib.nodes.comfyui_layerdiffuse.*] ignore_errors = True ignore_missing_imports = True diff --git a/pyproject.toml b/pyproject.toml index 4837a77c..7922ba9d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -78,7 +78,7 @@ exclude = ''' [tool.ruff] # XXX this isn't part of CI yet line-length=119 -exclude=["ComfyUI", "comfy_controlnet_preprocessors", "facerestore", "comfy_qr", "comfyui_layerdiffuse", "build"] +exclude=["ComfyUI", "comfy_controlnet_preprocessors", "facerestore_cf", "comfy_qr", "comfyui_layerdiffuse", "build"] ignore=[ # "F401", # imported but unused "E402", # Module level import not at top of file diff --git a/requirements.txt b/requirements.txt index ab2ca5ec..90cec3c1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -50,3 +50,4 @@ kornia qrcode spandrel spandrel_extra_arches +lpips diff --git a/tests/test_horde_pp.py b/tests/test_horde_pp.py index 72dccc6a..3bb5f442 100644 --- a/tests/test_horde_pp.py +++ b/tests/test_horde_pp.py @@ -66,8 +66,8 @@ def post_processor_check( similarity_constraints = ImageSimilarityConstraints( cosine_fail_floor=CosineSimilarityResultCode.PERCEPTUALLY_IDENTICAL, cosine_warn_floor=CosineSimilarityResultCode.EXTREMELY_SIMILAR, - histogram_fail_threshold=HistogramDistanceResultCode.VERY_DISSIMILAR_DISTRIBUTION, - histogram_warn_threshold=HistogramDistanceResultCode.SIMILAR_DISTRIBUTION, + histogram_fail_threshold=HistogramDistanceResultCode.VERY_SIMILAR_DISTRIBUTION, + 
histogram_warn_threshold=HistogramDistanceResultCode.EXTREMELY_SIMILAR_DISTRIBUTION, ) assert cls.shared_model_manager.manager.download_model(model_name) assert cls.shared_model_manager.manager.is_model_available(model_name) is True
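
A minimal, hypothetical usage sketch for the r_chainner GFPGAN loader introduced by this patch (the sketch itself is not part of the change). It exercises only code added in this diff: model_loading.load_state_dict unwraps a params_ema / params-ema / params wrapper when present and returns a GFPGANv1Clean instance once the GFPGAN keys (toRGB.0.weight, stylegan_decoder.style_mlp.1.weight) are found. The checkpoint filename below is illustrative; note that the patched facelib.utils.misc.load_file_from_url now simply resolves file_name under SharedModelManager.manager.gfpgan.model_folder_path rather than downloading, so the weights are assumed to already be on disk.

    import torch

    from hordelib.nodes.facerestore_cf.r_chainner.model_loading import load_state_dict

    # Illustrative path; assumes a GFPGAN "clean" checkpoint (params_ema weights) is available locally.
    state_dict = torch.load("GFPGANv1.4.pth", map_location="cpu")
    model = load_state_dict(state_dict)  # returns GFPGANv1Clean for GFPGAN-style checkpoints
    model.eval()

    with torch.no_grad():
        # GFPGANv1Clean is built around 512x512 face crops (out_size=512); the forward pass
        # returns the restored image plus intermediate RGB outputs when return_rgb=True.
        restored, _ = model(torch.zeros(1, 3, 512, 512))

For comparison, the updated pipeline_image_facefix.json wires the equivalent restoration path through FaceRestoreModelLoader ("CodeFormers.pth") feeding FaceRestoreCFWithModel with facedetection "retinaface_resnet50" and codeformer_fidelity 0.5.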