-
Notifications
You must be signed in to change notification settings - Fork 0
/
dcgan.cpp
205 lines (177 loc) · 7.74 KB
/
dcgan.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
#pragma once
#include <torch/torch.h>
#include <cmath>
#include <cstdio>
#include <iostream>
// The size of the noise vector fed to the generator.
const int64_t kNoiseSize = 100;
// The batch size for training.
const int64_t kBatchSize = 64;
// The number of epochs to train.
const int64_t kNumberOfEpochs = 30;
// Where to find the MNIST dataset.
const char* kDataFolder = "./data";
// After how many batches to create a new checkpoint periodically.
const int64_t kCheckpointEvery = 200;
// How many images to sample at every checkpoint.
const int64_t kNumberOfSamplesPerCheckpoint = 10;
// Set to `true` to restore models and optimizers from previously saved
// checkpoints.
const bool kRestoreFromCheckpoint = false;
// After how many batches to log a new update with the loss value.
const int64_t kLogInterval = 10;
using namespace torch;
struct DCGANGeneratorImpl : nn::Module {
DCGANGeneratorImpl(int kNoiseSize)
: conv1(nn::ConvTranspose2dOptions(kNoiseSize, 256, 4)
.bias(false)),
batch_norm1(256),
conv2(nn::ConvTranspose2dOptions(256, 128, 3)
.stride(2)
.padding(1)
.bias(false)),
batch_norm2(128),
conv3(nn::ConvTranspose2dOptions(128, 64, 4)
.stride(2)
.padding(1)
.bias(false)),
batch_norm3(64),
conv4(nn::ConvTranspose2dOptions(64, 1, 4)
.stride(2)
.padding(1)
.bias(false))
{
// register_module() is needed if we want to use the parameters() method later on
register_module("conv1", conv1);
register_module("conv2", conv2);
register_module("conv3", conv3);
register_module("conv4", conv4);
register_module("batch_norm1", batch_norm1);
register_module("batch_norm2", batch_norm2);
register_module("batch_norm3", batch_norm3);
}
torch::Tensor forward(torch::Tensor x) {
x = torch::relu(batch_norm1(conv1(x)));
x = torch::relu(batch_norm2(conv2(x)));
x = torch::relu(batch_norm3(conv3(x)));
x = torch::tanh(conv4(x));
return x;
}
nn::ConvTranspose2d conv1, conv2, conv3, conv4;
nn::BatchNorm2d batch_norm1, batch_norm2, batch_norm3;
};
TORCH_MODULE(DCGANGenerator);
int main_example(int argc, const char* argv[]) {
torch::manual_seed(1);
// Create the device we pass around based on whether CUDA is available.
torch::Device device(torch::kCPU);
if (torch::cuda::is_available()) {
std::cout << "CUDA is available! Training on GPU." << std::endl;
device = torch::Device(torch::kCUDA);
}
DCGANGenerator generator(kNoiseSize);
generator->to(device);
nn::Sequential discriminator(
// Layer 1
nn::Conv2d(
nn::Conv2dOptions(1, 64, 4).stride(2).padding(1).bias(false)),
nn::LeakyReLU(nn::LeakyReLUOptions().negative_slope(0.2)),
// Layer 2
nn::Conv2d(
nn::Conv2dOptions(64, 128, 4).stride(2).padding(1).bias(false)),
nn::BatchNorm2d(128),
nn::LeakyReLU(nn::LeakyReLUOptions().negative_slope(0.2)),
// Layer 3
nn::Conv2d(
nn::Conv2dOptions(128, 256, 4).stride(2).padding(1).bias(false)),
nn::BatchNorm2d(256),
nn::LeakyReLU(nn::LeakyReLUOptions().negative_slope(0.2)),
// Layer 4
nn::Conv2d(
nn::Conv2dOptions(256, 1, 3).stride(1).padding(0).bias(false)),
nn::Sigmoid());
discriminator->to(device);
// Assume the MNIST dataset is available under `kDataFolder`;
auto dataset = torch::data::datasets::MNIST(kDataFolder)
.map(torch::data::transforms::Normalize<>(0.5, 0.5))
.map(torch::data::transforms::Stack<>());
const int64_t batches_per_epoch =
std::ceil(dataset.size().value() / static_cast<double>(kBatchSize));
std::unique_ptr<torch::data::StatelessDataLoader<decltype(dataset), torch::data::samplers::RandomSampler>>
data_loader = torch::data::make_data_loader(
std::move(dataset),
torch::data::DataLoaderOptions().batch_size(kBatchSize).workers(2));
torch::optim::Adam generator_optimizer(
generator->parameters(), torch::optim::AdamOptions(2e-4).betas(std::make_tuple(0.5, 0.5)));
torch::optim::Adam discriminator_optimizer(
discriminator->parameters(), torch::optim::AdamOptions(2e-4).betas(std::make_tuple(0.5, 0.5)));
if (kRestoreFromCheckpoint) {
torch::load(generator, "generator-checkpoint.pt");
torch::load(generator_optimizer, "generator-optimizer-checkpoint.pt");
torch::load(discriminator, "discriminator-checkpoint.pt");
torch::load(
discriminator_optimizer, "discriminator-optimizer-checkpoint.pt");
}
int64_t checkpoint_counter = 1;
for (int64_t epoch = 1; epoch <= kNumberOfEpochs; ++epoch) {
int64_t batch_index = 0;
for (torch::data::Example<>& batch : *data_loader) {
// Train discriminator with real images.
discriminator->zero_grad();
torch::Tensor real_images = batch.data.to(device);
torch::Tensor real_labels =
torch::empty(batch.data.size(0), device).uniform_(0.8, 1.0);
torch::Tensor real_output = discriminator->forward(real_images);
torch::Tensor d_loss_real =
torch::binary_cross_entropy(real_output, real_labels);
d_loss_real.backward();
// Train discriminator with fake images.
torch::Tensor noise =
torch::randn({ batch.data.size(0), kNoiseSize, 1, 1 }, device);
torch::Tensor fake_images = generator->forward(noise);
torch::Tensor fake_labels = torch::zeros(batch.data.size(0), device);
torch::Tensor fake_output = discriminator->forward(fake_images.detach());
torch::Tensor d_loss_fake =
torch::binary_cross_entropy(fake_output, fake_labels);
d_loss_fake.backward();
torch::Tensor d_loss = d_loss_real + d_loss_fake;
discriminator_optimizer.step();
// Train generator.
generator->zero_grad();
fake_labels.fill_(1);
fake_output = discriminator->forward(fake_images);
torch::Tensor g_loss =
torch::binary_cross_entropy(fake_output, fake_labels);
g_loss.backward();
generator_optimizer.step();
batch_index++;
if (batch_index % kLogInterval == 0) {
std::printf(
"\r[%2ld/%2ld][%3ld/%3ld] D_loss: %.4f | G_loss: %.4f\n",
epoch,
kNumberOfEpochs,
batch_index,
batches_per_epoch,
d_loss.item<float>(),
g_loss.item<float>());
}
if (batch_index % kCheckpointEvery == 0) {
// Checkpoint the model and optimizer state.
torch::save(generator, "generator-checkpoint.pt");
torch::save(generator_optimizer, "generator-optimizer-checkpoint.pt");
torch::save(discriminator, "discriminator-checkpoint.pt");
torch::save(
discriminator_optimizer, "discriminator-optimizer-checkpoint.pt");
// Sample the generator and save the images.
torch::Tensor samples = generator->forward(torch::randn(
{ kNumberOfSamplesPerCheckpoint, kNoiseSize, 1, 1 }, device));
torch::save(
(samples + 1.0) / 2.0,
torch::str("dcgan-sample-", checkpoint_counter, ".pt"));
std::cout << "\n-> checkpoint " << ++checkpoint_counter << '\n';
}
}
}
std::cout << "Training complete!" << std::endl;
return 0;
}