# Create a dedicated conda environment named SGA running Python 3.6.
conda create -n SGA python=3.6
# Activate it so the installs below go into SGA rather than into whichever
# environment happens to be active (on conda < 4.4 use `source activate SGA`).
conda activate SGA
# Pinned versions: install TensorFlow from conda, the remaining packages via pip.
conda install tensorflow==1.11.0
pip install kfac==0.1.0
pip install dm-sonnet==1.27
pip install tensorflow-probability==0.4.0
pip install wrapt
pip install matplotlib
# Snapshot the finished environment to SGA.yaml so it can be reproduced.
conda env export -n SGA > SGA.yaml
# Alternative to the steps above: recreate the environment from the snapshot,
# e.g. on another machine. NOTE: conda is expected to refuse this if an
# environment named SGA already exists, so do not run it right after the
# commands above on the same machine.
conda env create -n SGA -f SGA.yaml
def train(train_op, x_fake, z, init, disc_loss, gen_loss, z_dim,
          n_iter=10001, n_save=2000):
    """Run the training loop, periodically logging losses and saving a plot.

    Every ``n_save`` iterations (including iteration 0) the current losses are
    printed, ``x_fake`` is evaluated on ten fixed noise batches, the first two
    columns of the result are passed to ``kde`` and the current matplotlib
    figure is written to ``./SGA_<iteration>.png``.

    Args:
        train_op: Op executed once per iteration through ``sess.run``.
        x_fake: Tensor evaluated to obtain samples; its static first dimension
            is used as the batch size.
        z: Tensor/placeholder that is fed with the fixed test noise.
        init: Op run once before the loop starts (presumably the variable
            initializer -- confirm against the caller).
        disc_loss: Tensor fetched each iteration and printed with ``%.4f``.
        gen_loss: Tensor fetched each iteration and printed with ``%.4f``.
        z_dim: Second dimension of each generated noise batch.
        n_iter: Number of training iterations.
        n_save: Interval, in iterations, between log/plot checkpoints.

    Raises:
        ValueError: If the first dimension of ``x_fake`` is not statically
            known, so no fixed-size test noise can be generated.
    """
    bbox = [-2, 2, -2, 2]
    batch_size = x_fake.get_shape()[0].value
    if batch_size is None:
        # Fail with a clear message instead of the TypeError that
        # np.random.randn(None, ...) would raise.
        raise ValueError(
            "x_fake must have a statically known batch dimension")

    # The noise batches are drawn once so every checkpoint plots samples
    # produced from the same inputs.
    ztest = [np.random.randn(batch_size, z_dim) for _ in range(10)]

    with tf.Session() as sess:
        sess.run(init)

        for i in range(n_iter):
            disc_loss_out, gen_loss_out, _ = sess.run(
                [disc_loss, gen_loss, train_op])

            if i % n_save == 0:
                print('i = %d, discriminant loss = %.4f, generator loss =%.4f' %
                      (i, disc_loss_out, gen_loss_out))
                x_out = np.concatenate(
                    [sess.run(x_fake, feed_dict={z: zt}) for zt in ztest],
                    axis=0)
                # kde is defined elsewhere; presumably it draws a density plot
                # of the two sample coordinates -- confirm.
                kde(x_out[:, 0], x_out[:, 1], bbox=bbox)
                # Save the generated figure.
                plt.savefig("./{}_{}.png".format("SGA", i))