-
Notifications
You must be signed in to change notification settings - Fork 83
/
Copy pathtraining_vqgan.py
120 lines (95 loc) · 5.79 KB
/
training_vqgan.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import os
import argparse
from tqdm import tqdm
import numpy as np
import torch
import torch.nn.functional as F
from torchvision import utils as vutils
from discriminator import Discriminator
from lpips import LPIPS
from vqgan import VQGAN
from utils import load_data, weights_init
class TrainVQGAN:
    """Build a VQGAN, a discriminator and an LPIPS loss, then train them.

    Constructing the class runs the whole training loop. Every 10 steps a
    grid of inputs and reconstructions is written to ``results/``. After
    every epoch the VQGAN weights are written to ``checkpoints/``.
    """

    def __init__(self, args):
        """Set up models and optimizers from ``args`` and immediately train."""
        self.vqgan = VQGAN(args).to(device=args.device)
        self.discriminator = Discriminator(args).to(device=args.device)
        self.discriminator.apply(weights_init)
        # LPIPS is only used as a loss network: none of its parameters are
        # handed to an optimizer in configure_optimizers().
        self.perceptual_loss = LPIPS().eval().to(device=args.device)

        self.opt_vq, self.opt_disc = self.configure_optimizers(args)

        self.prepare_training()
        self.train(args)

    def configure_optimizers(self, args):
        """Create one Adam optimizer for the VQGAN and one for the discriminator.

        Both use ``args.learning_rate`` and ``betas=(args.beta1, args.beta2)``.
        The VQGAN optimizer covers the encoder, decoder, codebook,
        quant_conv and post_quant_conv parameters.

        Returns:
            tuple: ``(opt_vq, opt_disc)``.
        """
        lr = args.learning_rate
        opt_vq = torch.optim.Adam(
            list(self.vqgan.encoder.parameters()) +
            list(self.vqgan.decoder.parameters()) +
            list(self.vqgan.codebook.parameters()) +
            list(self.vqgan.quant_conv.parameters()) +
            list(self.vqgan.post_quant_conv.parameters()),
            lr=lr, eps=1e-08, betas=(args.beta1, args.beta2),
        )
        opt_disc = torch.optim.Adam(
            self.discriminator.parameters(),
            lr=lr, eps=1e-08, betas=(args.beta1, args.beta2),
        )
        return opt_vq, opt_disc

    @staticmethod
    def prepare_training():
        """Create the ``results`` and ``checkpoints`` output directories."""
        os.makedirs("results", exist_ok=True)
        os.makedirs("checkpoints", exist_ok=True)

    def train(self, args):
        """Train for ``args.epochs`` epochs over the data returned by ``load_data``.

        Every 10 steps a sample grid is saved to ``results/{epoch}_{step}.jpg``.
        After each epoch the VQGAN state dict is saved to
        ``checkpoints/vqgan_epoch_{epoch}.pt``.
        """
        train_dataset = load_data(args)
        steps_per_epoch = len(train_dataset)
        for epoch in range(args.epochs):
            with tqdm(train_dataset) as pbar:
                for i, imgs in enumerate(pbar):
                    imgs = imgs.to(device=args.device)
                    global_step = epoch * steps_per_epoch + i
                    vq_loss, gan_loss, decoded_images = self._train_step(imgs, global_step, args)

                    if i % 10 == 0:
                        self._save_samples(imgs, decoded_images, epoch, i)

                    pbar.set_postfix(
                        VQ_Loss=round(vq_loss.item(), 5),
                        GAN_Loss=round(gan_loss.item(), 3),
                    )
            torch.save(self.vqgan.state_dict(), os.path.join("checkpoints", f"vqgan_epoch_{epoch}.pt"))

    def _train_step(self, imgs, global_step, args):
        """Run one VQGAN update followed by one discriminator update.

        Returns:
            tuple: ``(vq_loss, gan_loss, decoded_images)``.
        """
        decoded_images, _, q_loss = self.vqgan(imgs)
        disc_factor = self.vqgan.adopt_weight(args.disc_factor, global_step, threshold=args.disc_start)

        # ---- VQGAN (generator) update -------------------------------------
        perceptual_loss = self.perceptual_loss(imgs, decoded_images)
        rec_loss = torch.abs(imgs - decoded_images)
        perceptual_rec_loss = (
            args.perceptual_loss_factor * perceptual_loss + args.rec_loss_factor * rec_loss
        ).mean()

        disc_fake_for_gen = self.discriminator(decoded_images)
        g_loss = -torch.mean(disc_fake_for_gen)
        lam = self.vqgan.calculate_lambda(perceptual_rec_loss, g_loss)
        vq_loss = perceptual_rec_loss + q_loss + disc_factor * lam * g_loss

        self.opt_vq.zero_grad()
        vq_loss.backward()
        self.opt_vq.step()

        # ---- Discriminator update -----------------------------------------
        # The reconstructions are detached so that the hinge loss below only
        # produces gradients for the discriminator. Without the detach, this
        # loss would also push the VQGAN toward *less* realistic outputs.
        # Any discriminator grads left by vq_loss.backward() are cleared by
        # zero_grad() before this loss is backpropagated.
        disc_real = self.discriminator(imgs)
        disc_fake = self.discriminator(decoded_images.detach())
        d_loss_real = torch.mean(F.relu(1. - disc_real))
        d_loss_fake = torch.mean(F.relu(1. + disc_fake))
        gan_loss = disc_factor * 0.5 * (d_loss_real + d_loss_fake)

        self.opt_disc.zero_grad()
        gan_loss.backward()
        self.opt_disc.step()

        return vq_loss, gan_loss, decoded_images

    @staticmethod
    def _save_samples(imgs, decoded_images, epoch, step):
        """Save the first four inputs above their reconstructions to ``results/``."""
        with torch.no_grad():
            # Inputs and reconstructions are compared directly by the L1 loss,
            # so they share one value range (presumably [-1, 1], given the
            # (x + 1) / 2 mapping). Apply that same mapping to both halves.
            real_fake_images = torch.cat((imgs[:4], decoded_images[:4])).add(1).mul(0.5)
            vutils.save_image(real_fake_images, os.path.join("results", f"{epoch}_{step}.jpg"), nrow=4)
def build_arg_parser():
    """Return the command-line parser for VQGAN training.

    ``ArgumentDefaultsHelpFormatter`` appends each option's actual default
    to its help text. This keeps ``--help`` from drifting out of sync with
    the defaults.
    """
    parser = argparse.ArgumentParser(
        description="VQGAN",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument('--latent-dim', type=int, default=256, help='Latent dimension n_z')
    parser.add_argument('--image-size', type=int, default=256, help='Image height and width')
    parser.add_argument('--num-codebook-vectors', type=int, default=1024, help='Number of codebook vectors')
    parser.add_argument('--beta', type=float, default=0.25, help='Commitment loss scalar')
    parser.add_argument('--image-channels', type=int, default=3, help='Number of channels of images')
    parser.add_argument('--dataset-path', type=str, default='/data', help='Path to data')
    parser.add_argument('--device', type=str, default="cuda", help='Which device the training is on')
    parser.add_argument('--batch-size', type=int, default=6, help='Input batch size for training')
    parser.add_argument('--epochs', type=int, default=100, help='Number of epochs to train')
    parser.add_argument('--learning-rate', type=float, default=2.25e-05, help='Learning rate')
    parser.add_argument('--beta1', type=float, default=0.5, help='Adam beta1 parameter')
    parser.add_argument('--beta2', type=float, default=0.9, help='Adam beta2 parameter')
    parser.add_argument('--disc-start', type=int, default=10000,
                        help='Global training step at which to start the discriminator')
    parser.add_argument('--disc-factor', type=float, default=1.,
                        help='Weighting factor for the adversarial losses')
    parser.add_argument('--rec-loss-factor', type=float, default=1., help='Weighting factor for reconstruction loss.')
    parser.add_argument('--perceptual-loss-factor', type=float, default=1., help='Weighting factor for perceptual loss.')
    return parser


if __name__ == '__main__':
    # --dataset-path is used exactly as given on the command line; it is no
    # longer overwritten by a hard-coded, machine-specific path.
    args = build_arg_parser().parse_args()
    train_vqgan = TrainVQGAN(args)