Skip to content

Latest commit

 

History

History
45 lines (39 loc) · 1.34 KB

测试笔记.md

File metadata and controls

45 lines (39 loc) · 1.34 KB

SGA 配置代码

# Create a dedicated Python 3.6 environment named SGA.
conda create -n SGA python=3.6
# Activate it so the installs below land in SGA instead of whichever
# environment happens to be active (conda < 4.4: `source activate SGA`).
conda activate SGA
conda install tensorflow==1.11.0
pip install kfac==0.1.0
pip install dm-sonnet==1.27
pip install tensorflow-probability==0.4.0
pip install wrapt
pip install matplotlib

将服务器A上的 SGA 环境导出为 .yaml 文件

# Export the specification of the SGA environment and redirect it into SGA.yaml.
conda env export -n SGA > SGA.yaml


在服务器B上根据该 .yaml 文件创建相同配置的环境

# Create an environment named SGA from the specification stored in SGA.yaml.
conda env create -n SGA -f SGA.yaml


将生成的图像保存

def train(train_op, x_fake, z, init, disc_loss, gen_loss, z_dim,
          n_iter=10001, n_save=2000):
    """Run the training op and periodically save a plot of generated samples.

    Opens a ``tf.Session``, runs ``init`` once, then runs ``train_op``
    ``n_iter`` times while fetching both losses.  On every iteration whose
    index is a multiple of ``n_save`` it prints the two losses, evaluates
    ``x_fake`` on ten fixed noise batches, passes the first two columns of
    the stacked result to ``kde`` and saves the current matplotlib figure as
    ``./SGA_<iteration>.png``.

    Args:
        train_op: Op executed once per iteration.
        x_fake: Tensor evaluated to obtain samples; its static first
            dimension is used as the batch size of the noise.
        z: Tensor fed with noise (via ``feed_dict``) when evaluating
            ``x_fake``.
        init: Op run once before the loop starts.
        disc_loss: Tensor fetched each iteration and printed with ``%.4f``.
        gen_loss: Tensor fetched each iteration and printed with ``%.4f``.
        z_dim: Second dimension of each noise batch.
        n_iter: Number of iterations to run.
        n_save: Interval, in iterations, between printed/saved snapshots.

    Returns:
        None.
    """
    plot_bounds = [-2, 2, -2, 2]
    batch_size = x_fake.get_shape()[0].value

    # Draw the evaluation noise once, up front, so every snapshot is
    # generated from the same ten batches.
    noise_batches = []
    for _ in range(10):
        noise_batches.append(np.random.randn(batch_size, z_dim))

    with tf.Session() as sess:
        sess.run(init)

        for step in range(n_iter):
            fetches = [disc_loss, gen_loss, train_op]
            d_loss_val, g_loss_val, _ = sess.run(fetches)

            # Only report and plot on snapshot iterations.
            if step % n_save != 0:
                continue

            print('i = %d, discriminant loss = %.4f, generator loss =%.4f' %
                  (step, d_loss_val, g_loss_val))

            generated = []
            for noise in noise_batches:
                generated.append(sess.run(x_fake, feed_dict={z: noise}))
            samples = np.concatenate(generated, axis=0)

            # NOTE(review): kde is defined elsewhere; presumably it draws a
            # density plot onto the current figure -- confirm.
            kde(samples[:, 0], samples[:, 1], bbox=plot_bounds)
            # Save the generated figure, tagged with the iteration index.
            plt.savefig("./{}_{}.png".format("SGA", step))