-
Notifications
You must be signed in to change notification settings - Fork 14
/
main.py
43 lines (35 loc) · 1.14 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import tensorflow_probability as tfp
tfd = tfp.distributions
import pygrid as pygrid
from absl import app
import tensorflow.compat.v2 as tf
from train_utils import *
from train import Trainer
from train_distributed import Trainer as Trainer_dist
# Per-dataset FID evaluation overrides: dataset name -> (fid_n_samples, fid_n_batch).
# Datasets not listed here keep whatever values the command-line flags specify.
_FID_OVERRIDES = {
    **dict.fromkeys(
        (
            "celebahq128",
            "lsun_bedroom128",
            "lsun_bedroom64",
            "lsun_church128",
            "lsun_church64",
            "celeba",
        ),
        (2560, 640),
    ),
    "celebahq256": (1280, 160),
}


def main(argv):
    """Configure and launch a training run from the parsed command-line flags.

    Steps: override FID evaluation sizes for known large datasets, build the
    hyper-parameter dict from the flags, configure the device / TensorFlow
    runtime, select a single-device or TPU-distributed trainer, seed, and
    start training.

    Note: ``FLAGS`` is not imported explicitly in this file; presumably it is
    provided by the ``from train_utils import *`` star import — TODO confirm.

    Args:
        argv: Positional command-line arguments forwarded by ``absl.app.run``.
            Unused.
    """
    del argv  # Unused; absl passes leftover positional args here.

    exp_id = pygrid.get_exp_id(__file__)
    output_dir = pygrid.get_output_dir(exp_id, './')

    # Large/high-resolution datasets use fixed FID sample and batch sizes,
    # overriding anything given on the command line.
    fid_override = _FID_OVERRIDES.get(FLAGS.problem)
    if fid_override is not None:
        FLAGS.fid_n_samples, FLAGS.fid_n_batch = fid_override

    hps = AttrDict(get_flag_dict())
    hps.output = output_dir

    # A falsy device value (e.g. empty string / None) skips explicit GPU selection.
    if hps.device:
        set_gpu(hps.device)
    init_tf2(tf_eager=hps.eager, tf_memory_growth=True)

    if hps.tpu:
        # NOTE(review): `tf.distribute.experimental.TPUStrategy` is deprecated
        # in newer TF releases in favour of `tf.distribute.TPUStrategy`; kept
        # here for compatibility with whatever TF version this project pins.
        strategy = tf.distribute.experimental.TPUStrategy(setup_tpu())
        trainer_cls = Trainer_dist
    else:
        strategy = None
        trainer_cls = Trainer

    # Seed *before* constructing the trainer so any randomness consumed during
    # construction (e.g. weight initialisation, if the trainer builds its
    # networks eagerly) is reproducible. Previously the seed was set only
    # after construction.
    set_seed(hps.rnd_seed)
    model = trainer_cls(hps=hps)

    # The same directory is passed for all three path arguments; what each one
    # means is defined by Trainer.train, which is not visible here.
    model.train(output_dir, output_dir, output_dir, strategy)
# Script entry point: absl parses the command-line flags, then invokes main().
if __name__ == "__main__":
    app.run(main)