feat(scripts): adding a video generation script for gans
feat(scripts): updated gen_video_gan

feat(video-gan): max-frames arg

feat(scripts): adding --compare and --n-inferences, fixing preprocess and postprocess cvtColor

feat(scripts): --compare option for image gan generation

chore: black

docs(inference): video gan inference
Pierre Pereira authored and beniz committed Nov 9, 2023
1 parent 28ded2d commit 85d1922
Showing 3 changed files with 236 additions and 6 deletions.
42 changes: 42 additions & 0 deletions docs/source/inference.rst
@@ -63,6 +63,48 @@ The output file is the ``target.jpg`` image in the current directory:
- - original image given as input to the model
- ``target.jpg``: the output image with the glasses removed


*********************************************
Generate a video with a GAN generator model
*********************************************

The model used here is the same as the one presented in :doc:`tutorial_styletransfer_bdd100k`.

Download the video & pretrained model
=====================================

.. code:: bash

  wget https://www.joligen.com/models/clear2snowy_bdd100k.zip
  unzip clear2snowy_bdd100k.zip -d checkpoints
  rm clear2snowy_bdd100k.zip
  wget https://www.joligen.com/datasets/vids/051d857c-faeca4ad.mov

Run the inference script
========================

.. code:: bash

  cd scripts
  python3 gen_video_gan.py \
    --model-in-file ../checkpoints/latest_net_G_A.pth \
    --video-in ../051d857c-faeca4ad.mov \
    --video-out ../snowy-video.avi \
    --img-width 1280 \
    --img-height 720 \
    --max-frames 2000 \
    --fps 30 \
    --gpuid 0

The output file is the ``snowy-video.avi`` video in the parent directory.

You can optionally use ``--n-inferences`` to apply the model to each frame multiple
times. This increases the amount of snow generated by the model.

You can also use the ``--compare`` flag to concatenate each generated frame with the
corresponding original frame, side by side, in the output video.
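
For example, the following command (with illustrative values) applies the model three
times to each frame and writes a side-by-side comparison video:

.. code:: bash

  # example only: adjust paths and values to your setup
  python3 gen_video_gan.py \
    --model-in-file ../checkpoints/latest_net_G_A.pth \
    --video-in ../051d857c-faeca4ad.mov \
    --video-out ../snowy-video-compare.avi \
    --img-width 1280 \
    --img-height 720 \
    --max-frames 2000 \
    --fps 30 \
    --gpuid 0 \
    --n-inferences 3 \
    --compare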

******************************************
Generate an image with a diffusion model
******************************************
23 changes: 17 additions & 6 deletions scripts/gen_single_image.py
@@ -1,16 +1,17 @@
-import sys
-import os
 import json
+import os
+import sys
 
 sys.path.append("../")
-from models import gan_networks
-from options.train_options import TrainOptions
+import argparse
 
 import cv2
+import numpy as np
 import torch
+from models import gan_networks
+from options.train_options import TrainOptions
 from torchvision import transforms
 from torchvision.utils import save_image
-import numpy as np
-import argparse


def get_z_random(batch_size=1, nz=8, random_type="gauss"):
@@ -60,6 +61,11 @@ def load_model(modelpath, model_in_file, cpu, gpuid):
)
parser.add_argument("--cpu", action="store_true", help="whether to use CPU")
parser.add_argument("--gpuid", type=int, default=0, help="which GPU to use")
parser.add_argument(
    "--compare",
    action="store_true",
    help="Concatenate the true image and the transformed image",
)
args = parser.parse_args()

# loading model
@@ -74,6 +80,7 @@ def load_model(modelpath, model_in_file, cpu, gpuid):
img_width = args.img_width if args.img_width is not None else opt.data_crop_size
img_height = args.img_height if args.img_height is not None else opt.data_crop_size
img = cv2.imread(args.img_in)
original_img = img.copy()
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = cv2.resize(img, (img_width, img_height), interpolation=cv2.INTER_CUBIC)

@@ -110,5 +117,9 @@ def load_model(modelpath, model_in_file, cpu, gpuid):
out_img = (np.transpose(out_img, (1, 2, 0)) + 1) / 2.0 * 255.0
# print(out_img)
out_img = cv2.cvtColor(out_img, cv2.COLOR_RGB2BGR)

if args.compare:
    out_img = np.concatenate((original_img, out_img), axis=1)

cv2.imwrite(args.img_out, out_img)
print("Successfully generated image ", args.img_out)
177 changes: 177 additions & 0 deletions scripts/gen_video_gan.py
@@ -0,0 +1,177 @@
import argparse
import json
import os
import sys
from pathlib import Path

import cv2
import numpy as np
import torch
from torchvision import transforms
from torchvision.utils import save_image
from tqdm import tqdm

sys.path.append("../")

from models import gan_networks
from options.train_options import TrainOptions


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--model-in-file",
        help="file path to generator model (.pth file)",
        type=Path,
        required=True,
    )
    parser.add_argument(
        "--video-in", help="video to transform", type=Path, required=True
    )
    parser.add_argument(
        "--video-out", help="transformed video", type=Path, required=True
    )
    parser.add_argument(
        "--img-width", type=int, help="image width, defaults to model crop size"
    )
    parser.add_argument(
        "--img-height", type=int, help="image height, defaults to model crop size"
    )
    parser.add_argument(
        "--max-frames", type=int, help="Select total number of frames to generate"
    )
    parser.add_argument("--fps", type=int, help="select FPS")
    parser.add_argument("--cpu", action="store_true", help="whether to use CPU")
    parser.add_argument("--gpuid", type=int, default=0, help="which GPU to use")
    parser.add_argument(
        "--compare",
        action="store_true",
        help="put the input video on the left side to compare",
    )
    parser.add_argument(
        "--n-inferences",
        type=int,
        default=1,
        help="Number of recursive inferences per frame",
    )
    return parser.parse_args()


def get_z_random(
    batch_size: int = 1, nz: int = 8, random_type: str = "gauss"
) -> torch.Tensor:
    if random_type == "uni":
        z = torch.rand(batch_size, nz) * 2.0 - 1.0
    elif random_type == "gauss":
        z = torch.randn(batch_size, nz)
    return z.detach()


def iter_video_frames(video_path: Path, max_frames: int) -> np.ndarray:
"""Iterate over frames in a video."""
cap = cv2.VideoCapture(str(video_path))
max_frames = min(max_frames, int(cap.get(cv2.CAP_PROP_FRAME_COUNT)))

for _ in tqdm(range(max_frames), desc="Processing video frames"):
ret, frame = cap.read()
if not ret:
break
yield frame


def preprocess_frame(
    frame: np.ndarray, img_width: int, img_height: int, transforms: transforms.Compose
) -> torch.Tensor:
    """Preprocess a single frame."""
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    frame = cv2.resize(frame, (img_width, img_height), interpolation=cv2.INTER_CUBIC)
    frame = transforms(frame)
    return frame


def postprocess_frame(frame: torch.Tensor) -> np.ndarray:
    frame = frame.detach().cpu().float().numpy()
    frame = np.transpose(frame, (1, 2, 0))
    frame = (frame + 1) / 2.0 * 255.0
    frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
    frame = frame.astype(np.uint8)
    return frame


def load_model(model_dir: Path, model_filename: Path, device: torch.device):
    train_json_path = model_dir / "train_config.json"
    with open(train_json_path, "r") as jsonf:
        train_json = json.load(jsonf)
    opt = TrainOptions().parse_json(train_json, set_device=False)
    if opt.model_multimodal:
        opt.model_input_nc += opt.train_mm_nz
    opt.jg_dir = "../"

    model = gan_networks.define_G(**vars(opt))
    model.eval()
    model.load_state_dict(torch.load(model_dir / model_filename, map_location=device))

    model = model.to(device)
    return model, opt


if __name__ == "__main__":
    args = parse_args()

    device = torch.device("cpu") if args.cpu else torch.device(f"cuda:{args.gpuid}")

    # Load the model.
    model_dir = args.model_in_file.parent
    print(f"Model directory {model_dir}.")
    model, opt = load_model(model_dir, args.model_in_file.name, device)

    img_width = args.img_width if args.img_width is not None else opt.data_crop_size
    img_height = args.img_height if args.img_height is not None else opt.data_crop_size
    transforms = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )

    video_width = img_width * 2 if args.compare else img_width
    video_height = img_height
    video_writer = cv2.VideoWriter(
        str(args.video_out),
        cv2.VideoWriter_fourcc("M", "J", "P", "G"),
        args.fps,
        (video_width, video_height),
    )

    # Optional noise.
    # Noise is sampled only once. The same noise is used for all video frames.
    if opt.model_multimodal:
        z_random = get_z_random(batch_size=1, nz=opt.train_mm_nz)
        z_random = z_random.to(device)

    with torch.inference_mode():
        for frame in iter_video_frames(args.video_in, args.max_frames):
            original_frame = frame.copy()
            for _ in range(args.n_inferences):
                frame = preprocess_frame(frame, img_width, img_height, transforms)
                frame = frame.to(device)
                frame = frame.unsqueeze(0)

                if opt.model_multimodal:
                    z_real = z_random.view(z_random.size(0), z_random.size(1), 1, 1)
                    z_real = z_real.expand(
                        z_random.size(0), z_random.size(1), frame.size(2), frame.size(3)
                    )
                    frame = torch.cat((frame, z_real), dim=1)

                frame = model(frame)[0]
                frame = postprocess_frame(frame)

            if args.compare:
                frame = np.concatenate((original_frame, frame), axis=1)

            video_writer.write(frame)

    print(f"Saving video to {args.video_out}.")
    video_writer.release()
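
An optional sanity-check sketch (not part of this commit) that reads back the frame count and FPS of the generated video, using the same OpenCV calls as ``iter_video_frames``:

  import cv2

  # sketch only: path assumes the snowy-video.avi produced by the documented command
  cap = cv2.VideoCapture("../snowy-video.avi")
  n_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))  # should not exceed --max-frames
  fps = cap.get(cv2.CAP_PROP_FPS)  # should match --fps
  print(f"{n_frames} frames at {fps:.1f} fps")
  cap.release()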
