-
Notifications
You must be signed in to change notification settings - Fork 41
/
train.py
163 lines (143 loc) · 7.02 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
import torch
import argparse
import os
from models.anime_gan import GeneratorV1
from models.anime_gan_v2 import GeneratorV2
from models.anime_gan_v3 import GeneratorV3
from models.anime_gan import Discriminator
from datasets import AnimeDataSet
from utils.common import load_checkpoint
from trainer import Trainer
from utils.logger import get_logger
def parse_args():
    """Build and parse the command-line configuration for AnimeGAN training.

    Returns:
        argparse.Namespace: parsed training configuration (reads sys.argv).
    """
    parser = argparse.ArgumentParser()
    # Dataset locations
    parser.add_argument('--real_image_dir', type=str, default='dataset/train_photo')
    parser.add_argument('--anime_image_dir', type=str, default='dataset/Hayao')
    parser.add_argument('--test_image_dir', type=str, default='dataset/test/HR_photo')
    # Restrict to the versions main() actually dispatches on; previously an
    # unknown value slipped through and crashed later with NameError (G unset).
    parser.add_argument('--model', type=str, default='v1', choices=['v1', 'v2', 'v3'],
                        help="AnimeGAN version, can be {'v1', 'v2', 'v3'}")
    parser.add_argument('--epochs', type=int, default=70)
    parser.add_argument('--init_epochs', type=int, default=10, help="Epochs of the G-only content pretraining phase")
    parser.add_argument('--batch_size', type=int, default=8)
    parser.add_argument('--exp_dir', type=str, default='runs', help="Experiment directory")
    parser.add_argument('--gan_loss', type=str, default='lsgan', help='lsgan / hinge / bce')
    parser.add_argument('--resume', action='store_true', help="Continue from current dir")
    # NOTE: these three are string flags ('False' means "not given") rather than
    # paths-or-None; main() compares them case-insensitively against 'false'.
    parser.add_argument('--resume_G_init', type=str, default='False')
    parser.add_argument('--resume_G', type=str, default='False')
    parser.add_argument('--resume_D', type=str, default='False')
    parser.add_argument('--device', type=str, default='cuda')
    parser.add_argument('--use_sn', action='store_true', help="Spectral norm in the discriminator")
    parser.add_argument('--cache', action='store_true', help="Turn on disk cache")
    parser.add_argument('--amp', action='store_true', help="Turn on Automatic Mixed Precision")
    parser.add_argument('--save_interval', type=int, default=1)
    parser.add_argument('--debug_samples', type=int, default=0)
    parser.add_argument('--num_workers', type=int, default=2)
    parser.add_argument('--imgsz', type=int, nargs="+", default=[256],
                        help="Image sizes, can provide multiple values, image size will increase after a proportion of epochs")
    parser.add_argument('--resize_method', type=str, default="crop",
                        help="Resize image method if origin photo larger than imgsz")
    # Loss stuff
    parser.add_argument('--lr_g', type=float, default=2e-5)
    parser.add_argument('--lr_d', type=float, default=4e-5)
    parser.add_argument('--init_lr', type=float, default=1e-4)
    parser.add_argument('--wadvg', type=float, default=300.0, help='Adversarial loss weight for G')
    parser.add_argument('--wadvd', type=float, default=300.0, help='Adversarial loss weight for D')
    parser.add_argument(
        '--gray_adv', action='store_true',
        help="If given, train adversarial with gray scale image instead of RGB image to reduce color effect of anime style")
    # Loss weight VGG19
    parser.add_argument('--wcon', type=float, default=1.5, help='Content loss weight')  # 1.5 for Hayao, 2.0 for Paprika, 1.2 for Shinkai
    parser.add_argument('--wgra', type=float, default=5.0, help='Gram loss weight')  # 2.5 for Hayao, 0.6 for Paprika, 2.0 for Shinkai
    parser.add_argument('--wcol', type=float, default=30.0, help='Color loss weight')  # 15. for Hayao, 50. for Paprika, 10. for Shinkai
    parser.add_argument('--wtvar', type=float, default=1.0, help='Total variation loss')  # 1. for Hayao, 0.1 for Paprika, 1. for Shinkai
    parser.add_argument('--d_layers', type=int, default=2, help='Discriminator conv layers')
    parser.add_argument('--d_noise', action='store_true')
    # DDP
    parser.add_argument('--ddp', action='store_true')
    parser.add_argument("--local-rank", default=0, type=int)
    parser.add_argument("--world-size", default=2, type=int)
    return parser.parse_args()
def check_params(args):
    """Validate the config and derive fields in place.

    Sets ``args.dataset`` from the directory basenames, e.g.
    dataset/train_photo + dataset/Hayao -> train_photo_Hayao.

    Raises:
        ValueError: if ``args.gan_loss`` is not one of lsgan / hinge / bce.
    """
    args.dataset = f"{os.path.basename(args.real_image_dir)}_{os.path.basename(args.anime_image_dir)}"
    # Explicit raise instead of assert: asserts are stripped under `python -O`,
    # which would let an unsupported loss silently through.
    if args.gan_loss not in {'lsgan', 'hinge', 'bce'}:
        raise ValueError(f'{args.gan_loss} is not supported')
def main(args, logger):
    """Build the generator/discriminator, restore weights, and run training.

    Args:
        args: parsed namespace from parse_args() (mutated in place: device,
            batch_size, debug_samples, init_epochs may be overridden).
        logger: logger produced by utils.logger.get_logger.

    Raises:
        ValueError: if ``args.model`` is not one of v1 / v2 / v3.
    """
    check_params(args)
    if not torch.cuda.is_available():
        logger.info("CUDA not found, use CPU")
        # Just for debugging purpose, set to minimum config
        # to avoid 🔥 the computer...
        args.device = 'cpu'
        args.debug_samples = 10
        args.batch_size = 2
    else:
        logger.info(f"Use GPU: {torch.cuda.get_device_name(0)}")

    # v2 uses layer norm in D; v1/v3 keep instance norm.
    norm_type = "instance"
    if args.model == 'v1':
        G = GeneratorV1(args.dataset)
    elif args.model == 'v2':
        G = GeneratorV2(args.dataset)
        norm_type = "layer"
    elif args.model == 'v3':
        G = GeneratorV3(args.dataset)
    else:
        # Fail fast: previously an unknown model left G undefined and the
        # script crashed later with a confusing NameError.
        raise ValueError(f"Unsupported model version: {args.model}")

    D = Discriminator(
        args.dataset,
        num_layers=args.d_layers,
        use_sn=args.use_sn,
        norm_type=norm_type,
    )

    start_e = 0        # epoch to resume GAN training from
    start_e_init = 0   # epoch to resume the G-only init phase from
    trainer = Trainer(
        generator=G,
        discriminator=D,
        config=args,
        logger=logger,
    )

    # Resume priority: explicit init checkpoint > explicit G+D pair > working dir.
    if args.resume_G_init.lower() != 'false':
        start_e_init = load_checkpoint(G, args.resume_G_init) + 1
        if args.local_rank == 0:
            logger.info(f"G content weight loaded from {args.resume_G_init}")
    elif args.resume_G.lower() != 'false' and args.resume_D.lower() != 'false':
        # You should provide both
        try:
            start_e = load_checkpoint(G, args.resume_G)
            if args.local_rank == 0:
                logger.info(f"G weight loaded from {args.resume_G}")
            load_checkpoint(D, args.resume_D)
            if args.local_rank == 0:
                logger.info(f"D weight loaded from {args.resume_D}")
            # If loaded both weight, turn off init G phrase
            args.init_epochs = 0
        except Exception as e:
            # Best-effort resume: fall back to training from scratch.
            logger.warning(f'Could not load checkpoint, train from scratch {e}')
    elif args.resume:
        # Try to load from working dir
        logger.info(f"Loading weight from {trainer.checkpoint_path_G}")
        start_e = load_checkpoint(G, trainer.checkpoint_path_G)
        logger.info(f"Loading weight from {trainer.checkpoint_path_D}")
        load_checkpoint(D, trainer.checkpoint_path_D)
        args.init_epochs = 0

    dataset = AnimeDataSet(
        args.anime_image_dir,
        args.real_image_dir,
        args.debug_samples,
        args.cache,
        imgsz=args.imgsz,
        resize_method=args.resize_method,
    )
    if args.local_rank == 0:
        logger.info(f"Start from epoch {start_e}, {start_e_init}")
    trainer.train(dataset, start_e, start_e_init)
if __name__ == '__main__':
    args = parse_args()
    # Suffix the experiment dir with the dataset basenames so runs for
    # different photo/anime pairs land in distinct directories.
    photo_name = os.path.basename(args.real_image_dir)
    style_name = os.path.basename(args.anime_image_dir)
    args.exp_dir = f"{args.exp_dir}_{photo_name}_{style_name}"
    os.makedirs(args.exp_dir, exist_ok=True)
    logger = get_logger(os.path.join(args.exp_dir, "train.log"))
    # Only rank 0 dumps the full config to avoid duplicate logs under DDP.
    if args.local_rank == 0:
        logger.info("# ==== Train Config ==== #")
        for key in vars(args):
            logger.info(f"{key} {getattr(args, key)}")
        logger.info("==========================")
    main(args, logger)