-
Notifications
You must be signed in to change notification settings - Fork 3
/
train.py
111 lines (77 loc) · 3.5 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import os
import pickle
import numpy as np
import tensorflow as tf
import config as conf
from data.mnist import Mnist
from models.vae import VAE
def main():
model_spec_name = "%s-model-spec.json" % conf.MODEL_NAME
model_rslt_name = "%s-results.pickle" % conf.MODEL_NAME
model_save_path = os.path.join(conf.MODEL_SAVE_DIR, conf.MODEL_NAME)
if not os.path.exists(model_save_path):
os.makedirs(model_save_path)
model_ckpt_path = os.path.join(model_save_path, "model-ckpt")
model_spec_path = os.path.join(model_save_path, model_spec_name)
model_rslt_path = os.path.join(model_save_path, model_rslt_name)
loader = Mnist()
features = np.vstack([loader.train_features, loader.test_features]).astype(np.float32)
num_sets = loader.num_train_sets + loader.num_test_sets
feature_depth = loader.feature_depth
feature_shape = loader.feature_shape
latent_depth = conf.LATENT_DEPTH
batch_size = conf.BATCH_SIZE
num_epochs = conf.NUM_EPOCHS
model = VAE(latent_depth, feature_depth)
opt = tf.keras.optimizers.Adam()
@tf.function
def train_step(x, eps):
    """Run a single optimisation step on one batch.

    The batch is encoded with ``model.encode``, a latent code is produced by
    ``model.reparam`` from the supplied noise and the encoder outputs, and
    that code is decoded with ``model.decode``. The encoder and decoder
    losses are each averaged with ``tf.reduce_mean`` and summed; one
    optimizer update on that sum is applied to the encoder and decoder
    trainable variables together.

    Args:
        x: Batch of inputs, forwarded to ``model.encode`` and also used as
            the target in ``model.decoder_loss``.
        eps: Noise forwarded to ``model.reparam``.

    Returns:
        Tuple ``(encoder_loss, decoder_loss, loss)``: the two averaged loss
        terms and their sum.
    """
    with tf.GradientTape() as tape:
        # Forward pass in training mode.
        mu, log_sigma = model.encode(x, training=True)
        latent = model.reparam(eps, mu, log_sigma)
        decoded = model.decode(latent, training=True)

        # Both loss terms are computed under the tape so they can be
        # differentiated below.
        enc_loss = tf.reduce_mean(model.encoder_loss(mu, log_sigma))
        dec_loss = tf.reduce_mean(model.decoder_loss(x, decoded))
        total_loss = enc_loss + dec_loss

    # Gathered after the forward pass (same ordering as before), in case the
    # sub-models only create their variables on first call.
    variables = (
        model.encoder.trainable_variables + model.decoder.trainable_variables
    )
    gradients = tape.gradient(total_loss, variables)
    opt.apply_gradients(zip(gradients, variables))

    return enc_loss, dec_loss, total_loss
ckpt = tf.train.Checkpoint(encoder=model.encoder, decoder=model.decoder)
steps_per_epoch = num_sets // batch_size
train_steps = steps_per_epoch * num_epochs
encoder_losses = []
decoder_losses = []
losses = []
encoder_losses_epoch = []
decoder_losses_epoch = []
losses_epoch = []
fs = []
for i in range(1, train_steps+1):
epoch = i // steps_per_epoch
idxes = np.random.choice(num_sets, batch_size, replace=False)
x_i = features[idxes]
eps_i = np.random.normal(size=[batch_size, latent_depth]).astype(np.float32)
encoder_loss_i, decoder_loss_i, loss_i = train_step(x_i, eps_i)
encoder_losses.append(encoder_loss_i)
decoder_losses.append(decoder_loss_i)
losses.append(loss_i)
if i % steps_per_epoch == 0:
f_eps = model.decode(eps_i, training=False)
encoder_loss_epoch = np.mean(encoder_losses[-steps_per_epoch:])
decoder_loss_epoch = np.mean(decoder_losses[-steps_per_epoch:])