This repository has been archived by the owner on Sep 25, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 2
/
gan.py
144 lines (118 loc) · 5.25 KB
/
gan.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import os
import sys
import argparse
import numpy as np
import mindspore
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import ms_function
from tqdm import tqdm
sys.path.append(os.pardir)
from layers import Dense
from img_utils import to_image
from dataset import create_dataset
# Directory that receives the sample grids saved as "images/<n>.png".
os.makedirs("images", exist_ok=True)

# Command-line hyper-parameters; every option has a default, so the script
# can be launched without arguments.
parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
# beta2 controls the second-moment estimate; the help text previously repeated
# the "first order" wording of --b1.
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of second order momentum of gradient")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--img_size", type=int, default=28, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=400, help="interval between image samples")
opt = parser.parse_args()

# Per-image shape as (channels, img_size, img_size).
img_shape = (opt.channels, opt.img_size, opt.img_size)
latent_dim = opt.latent_dim
class Generator(nn.Cell):
    """Fully connected generator that turns latent vectors into images.

    The stack goes from ``opt.latent_dim`` features through 128, 256, 512
    and 1024 features, projects to ``prod(img_shape)`` values, applies
    ``Tanh`` and reshapes the output to ``(batch, *img_shape)``.
    """

    def __init__(self):
        super().__init__()

        def block(in_feat, out_feat, normalize=True):
            """Build one stage: Dense, optional BatchNorm1d, LeakyReLU(0.2)."""
            dense = Dense(in_feat, out_feat)
            # NOTE(review): 0.8 is passed as the second positional argument of
            # BatchNorm1d -- presumably `eps` rather than `momentum`; confirm
            # against the MindSpore signature that this is intended.
            norm = (nn.BatchNorm1d(out_feat, 0.8),) if normalize else ()
            return [dense, *norm, nn.LeakyReLU(0.2)]

        # Consecutive pairs of widths define the hidden stages; only the
        # first stage skips normalization.
        widths = (opt.latent_dim, 128, 256, 512, 1024)
        stages = []
        for idx, (n_in, n_out) in enumerate(zip(widths, widths[1:])):
            stages.extend(block(n_in, n_out, normalize=idx > 0))
        stages.append(Dense(1024, int(np.prod(img_shape))))
        stages.append(nn.Tanh())
        self.model = nn.SequentialCell(*stages)

    def construct(self, z):
        """Map latent input ``z`` to a batch shaped ``(z_batch, *img_shape)``."""
        flat = self.model(z)
        return flat.view(flat.shape[0], *img_shape)
class Discriminator(nn.Cell):
    """Fully connected discriminator producing one sigmoid score per image.

    Each image is flattened to ``prod(img_shape)`` values and passed through
    Dense layers of width 512, 256 and 1, with LeakyReLU(0.2) between them
    and a final Sigmoid.
    """

    def __init__(self):
        super().__init__()
        n_pixels = int(np.prod(img_shape))
        stack = [
            Dense(n_pixels, 512),
            nn.LeakyReLU(0.2),
            Dense(512, 256),
            nn.LeakyReLU(0.2),
            Dense(256, 1),
            nn.Sigmoid(),
        ]
        self.model = nn.SequentialCell(*stack)

    def construct(self, img):
        """Flatten ``img`` per sample and return the model's output for it."""
        batch = img.shape[0]
        return self.model(img.view(batch, -1))
# Binary cross-entropy, averaged over the batch, shared by both networks.
adversarial_loss = nn.BCELoss(reduction='mean')

# Build the two networks (generator first, then discriminator).
generator = Generator()
discriminator = Discriminator()

# Give each network its own parameter-name tag and switch it to training
# mode. This happens before the optimizers are created so that they see the
# renamed parameters.
generator.update_parameters_name('generator')
generator.set_train()
discriminator.update_parameters_name('discriminator')
discriminator.set_train()

# One Adam optimizer per network, both using the CLI learning rate and betas.
optimizer_G = nn.Adam(
    generator.trainable_params(),
    learning_rate=opt.lr,
    beta1=opt.b1,
    beta2=opt.b2,
)
optimizer_G.update_parameters_name('optim_g')
optimizer_D = nn.Adam(
    discriminator.trainable_params(),
    learning_rate=opt.lr,
    beta1=opt.b1,
    beta2=opt.b2,
)
optimizer_D.update_parameters_name('optim_d')
def generator_forward(real_imgs, valid):
    """Compute the generator loss for one batch.

    Args:
        real_imgs: batch of real images; only its leading dimension is read,
            to decide how many noise vectors to draw.
        valid: target labels the discriminator output is compared against.

    Returns:
        Tuple ``(g_loss, gen_imgs)``: the adversarial loss and the images
        generated for this step.
    """
    batch_size = real_imgs.shape[0]
    # Draw one standard-normal latent vector per real image.
    sample_normal = ops.StandardNormal()
    noise = sample_normal((batch_size, latent_dim))
    fake_imgs = generator(noise)
    # The generator is scored on how well its images pass as `valid`.
    verdict = discriminator(fake_imgs)
    return adversarial_loss(verdict, valid), fake_imgs
def discriminator_forward(real_imgs, gen_imgs, valid, fake):
    """Compute the discriminator loss for one batch.

    Args:
        real_imgs: batch of real images, scored against ``valid``.
        gen_imgs: batch of generated images, scored against ``fake``.
        valid: target labels for real images.
        fake: target labels for generated images.

    Returns:
        The mean of the loss on real images and the loss on generated images.
    """
    loss_on_real = adversarial_loss(discriminator(real_imgs), valid)
    loss_on_fake = adversarial_loss(discriminator(gen_imgs), fake)
    return (loss_on_real + loss_on_fake) / 2
# Value-and-gradient wrappers. Gradients are taken with respect to each
# optimizer's parameters only (grad_position=None means no input gradients).
# The generator's forward also returns the generated images, hence has_aux.
grad_generator_fn = ops.value_and_grad(
    generator_forward,
    grad_position=None,
    weights=optimizer_G.parameters,
    has_aux=True,
)
grad_discriminator_fn = ops.value_and_grad(
    discriminator_forward,
    grad_position=None,
    weights=optimizer_D.parameters,
)
@ms_function
def train_step(imgs):
    """Run one generator update followed by one discriminator update.

    Args:
        imgs: batch of real images; its leading dimension sets the label size.

    Returns:
        Tuple ``(g_loss, d_loss, gen_imgs)`` for this step.
    """
    label_shape = (imgs.shape[0], 1)
    real_target = ops.ones(label_shape, mindspore.float32)
    fake_target = ops.zeros(label_shape, mindspore.float32)

    # Generator step: the forward pass also hands back the generated images.
    (loss_g, samples), grads_g = grad_generator_fn(imgs, real_target)
    optimizer_G(grads_g)

    # Discriminator step reuses the images produced by the generator step.
    loss_d, grads_d = grad_discriminator_fn(imgs, samples, real_target, fake_target)
    optimizer_D(grads_d)

    return loss_g, loss_d, samples
# Training data pipeline; `dataset_size` is the number of batches per epoch.
dataset = create_dataset('../../dataset', 'train', opt.img_size, opt.batch_size, num_parallel_workers=opt.n_cpu)
dataset_size = dataset.get_dataset_size()

for epoch in range(opt.n_epochs):
    # A fresh progress bar is created per epoch. The `with` block closes it
    # when the epoch ends or an exception escapes, instead of leaving one
    # unclosed bar behind per epoch.
    with tqdm(total=dataset_size) as t:
        t.set_description('Epoch %i' % epoch)
        for i, (imgs, _) in enumerate(dataset.create_tuple_iterator()):
            g_loss, d_loss, gen_imgs = train_step(imgs)
            t.set_postfix(g_loss=g_loss, d_loss=d_loss)
            t.update(1)

            # Global step counter across epochs; a 5x5 grid of the first 25
            # generated images is saved every `sample_interval` steps.
            batches_done = epoch * dataset_size + i
            if batches_done % opt.sample_interval == 0:
                to_image(gen_imgs[:25], "images/%d.png" % batches_done, nrow=5, normalize=True)