forked from genforce/interfacegan
-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathedit.py
113 lines (98 loc) · 4.86 KB
/
edit.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
# python3.7
"""Edits latent codes with respect to given boundary.
This file takes latent codes and a semantic boundary as inputs, and then shows
how the image synthesis changes when the latent codes are moved towards the
given boundary.
"""
import os.path
import argparse
import cv2
import numpy as np
from tqdm import tqdm
from models.model_settings import MODEL_POOL
from models.pggan_generator import PGGANGenerator
from models.stylegan_generator import StyleGANGenerator
from utils.logger import setup_logger
from utils.manipulator import linear_interpolate
def parse_args():
    """Builds the command-line interface of the editing script and parses it.

    Returns:
        argparse.Namespace with the attributes `model_name`, `output_dir`,
        `boundary_path`, `input_latent_codes_path`, `num`,
        `latent_space_type`, `start_distance`, `end_distance` and `steps`.
    """
    cli = argparse.ArgumentParser(
        description='Edit image synthesis with given semantic boundary.')

    # (flags, keyword options) pairs; registration order is kept fixed so the
    # generated `--help` text lists the options in this order.
    option_specs = [
        (('-m', '--model_name'),
         dict(type=str, required=True, choices=list(MODEL_POOL),
              help='Name of the model for generation. (required)')),
        (('-o', '--output_dir'),
         dict(type=str, required=True,
              help='Directory to save the output results. (required)')),
        (('-b', '--boundary_path'),
         dict(type=str, required=True,
              help='Path to the semantic boundary. (required)')),
        (('-i', '--input_latent_codes_path'),
         dict(type=str, default='',
              help=('If specified, will load latent codes from given path '
                    'instead of randomly sampling. (optional)'))),
        (('-n', '--num'),
         dict(type=int, default=1,
              help=('Number of images for editing. This field will be ignored '
                    'if `input_latent_codes_path` is specified. (default: 1)'))),
        (('-s', '--latent_space_type'),
         dict(type=str, default='z',
              choices=['z', 'Z', 'w', 'W', 'wp', 'wP', 'Wp', 'WP'],
              help='Latent space used in Style GAN. (default: `Z`)')),
        (('--start_distance',),
         dict(type=float, default=-3.0,
              help='Start point for manipulation in latent space. (default: -3.0)')),
        (('--end_distance',),
         dict(type=float, default=3.0,
              help='End point for manipulation in latent space. (default: 3.0)')),
        (('--steps',),
         dict(type=int, default=10,
              help='Number of steps for image editing. (default: 10)')),
    ]
    for flags, options in option_specs:
        cli.add_argument(*flags, **options)

    namespace = cli.parse_args()
    return namespace
def main():
    """Edits latent codes along a semantic boundary and saves the results.

    Parses the command line with `parse_args()`, builds the generator for the
    requested model, and loads the boundary. Latent codes are obtained either
    by loading `--input_latent_codes_path` (then `model.preprocess`) or by
    `model.easy_sample(--num)`. For every latent code, `linear_interpolate`
    produces `--steps` codes between `--start_distance` and `--end_distance`,
    and each one is synthesized and written as an image.

    Files written to `--output_dir`:
        boundary.npy      -- copy of the loaded boundary.
        latent_codes.npy  -- the latent codes actually edited.
        SSS_III.jpg       -- image for sample SSS at interpolation step III.

    Raises:
        NotImplementedError: if the model's `gan_type` is not supported.
        ValueError: if the boundary file does not exist, or if
            `--input_latent_codes_path` is given but is not an existing file.
        OSError: if an output image cannot be written.
    """
    args = parse_args()
    # NOTE(review): logger name looks copied from the data-generation script;
    # kept as-is because how setup_logger uses it is not visible here.
    logger = setup_logger(args.output_dir, logger_name='generate_data')

    logger.info('Initializing generator.')
    gan_type = MODEL_POOL[args.model_name]['gan_type']
    if gan_type == 'pggan':
        model = PGGANGenerator(args.model_name, logger)
        kwargs = {}  # PGGAN takes no extra synthesis/sampling arguments.
    elif gan_type == 'stylegan':
        model = StyleGANGenerator(args.model_name, logger)
        # StyleGAN must know which latent space the codes belong to.
        kwargs = {'latent_space_type': args.latent_space_type}
    else:
        raise NotImplementedError(f'Not implemented GAN type `{gan_type}`!')

    logger.info('Preparing boundary.')
    if not os.path.isfile(args.boundary_path):
        raise ValueError(f'Boundary `{args.boundary_path}` does not exist!')
    boundary = np.load(args.boundary_path)
    np.save(os.path.join(args.output_dir, 'boundary.npy'), boundary)

    logger.info('Preparing latent codes.')
    if args.input_latent_codes_path:
        # The user explicitly asked for a file: fail loudly rather than
        # silently falling back to random sampling on a typo'd path.
        if not os.path.isfile(args.input_latent_codes_path):
            raise ValueError(f'Latent code file '
                             f'`{args.input_latent_codes_path}` does not exist!')
        logger.info(f' Load latent codes from `{args.input_latent_codes_path}`.')
        latent_codes = np.load(args.input_latent_codes_path)
        latent_codes = model.preprocess(latent_codes, **kwargs)
    else:
        logger.info(' Sample latent codes randomly.')
        latent_codes = model.easy_sample(args.num, **kwargs)
    np.save(os.path.join(args.output_dir, 'latent_codes.npy'), latent_codes)

    total_num = latent_codes.shape[0]
    logger.info(f'Editing {total_num} samples.')
    for sample_id in tqdm(range(total_num), leave=False):
        # Slice (not index) to keep the leading batch axis of size 1.
        interpolations = linear_interpolate(latent_codes[sample_id:sample_id + 1],
                                            boundary,
                                            start_distance=args.start_distance,
                                            end_distance=args.end_distance,
                                            steps=args.steps)
        interpolation_id = 0
        for interpolations_batch in model.get_batch_inputs(interpolations):
            # kwargs is empty for PGGAN, so a single call serves both types.
            outputs = model.easy_synthesize(interpolations_batch, **kwargs)
            for image in outputs['image']:
                save_path = os.path.join(
                    args.output_dir,
                    f'{sample_id:03d}_{interpolation_id:03d}.jpg')
                # Reverse the channel axis; presumably the generator emits RGB
                # while cv2.imwrite expects BGR — confirm against the models.
                if not cv2.imwrite(save_path, image[:, :, ::-1]):
                    raise OSError(f'Failed to write image `{save_path}`!')
                interpolation_id += 1
        # Sanity check: exactly one image per interpolation step.
        assert interpolation_id == args.steps
        logger.debug(f' Finished sample {sample_id:3d}.')
    logger.info(f'Successfully edited {total_num} samples.')
if __name__ == '__main__':
    # Run the editing pipeline only when executed as a script, not on import.
    main()