gansynth_train.py

# Copyright 2019 The Magenta Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file has been altered from the original version.
r"""Train a progressive GANSynth model.
Example usage: (From base directory)
>>> python magenta/models/gansynth/train.py
To use a config of hyperparameters:
>>> python magenta/models/gansynth/train.py --config=mel_prog_hires
To use a config of hyperparameters and manual hparams:
>>> python magenta/models/gansynth/train.py --config=mel_prog_hires \
>>> --hparams='{"train_data_path":"/path/to/nsynth-train.tfrecord"}'
List of hyperparameters can be found in model.py.
Trains in a couple days on a single V100 GPU.
Adapted from the original Progressive GAN paper for images.
See https://arxiv.org/abs/1710.10196 for details about the model.
See https://github.com/tkarras/progressive_growing_of_gans for the original
theano implementation.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import importlib
import json
import os
import time

from absl import logging
import absl.flags
import lib.data_helpers as data_helpers
import lib.data_normalizer as data_normalizer
import lib.flags as lib_flags
import lib.model as lib_model
import lib.train_util as train_util
import lib.util as util
import tensorflow as tf

absl.flags.DEFINE_string('hparams', '{}', 'Flags dict as JSON string.')
absl.flags.DEFINE_string('config', '', 'Name of config module.')
FLAGS = absl.flags.FLAGS

tf.logging.set_verbosity(tf.logging.INFO)
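

# Illustrative invocation combining both flags (paths are placeholders;
# 'train_data_path' appears in the module docstring above and 'train_root_dir'
# is used in main() below; the remaining hyperparameter names are defined in
# model.py and the config modules):
#
#   python gansynth_train.py --config=mel_prog_hires \
#     --hparams='{"train_data_path": "/path/to/nsynth-train.tfrecord",
#                 "train_root_dir": "/tmp/gansynth/train"}'

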
def init_data_normalizer(config):
  """Initializes data normalizer."""
  normalizer = data_normalizer.registry[config['data_normalizer']](config)
  if normalizer.exists():
    return

  if config['task'] == 0:
    # Only the chief worker (task 0) builds the data pipeline and saves the
    # normalizer; all other workers poll until the saved normalizer exists.
    tf.reset_default_graph()
    data_helper = data_helpers.registry[config['data_type']](config)
    real_images, _ = data_helper.provide_data(batch_size=10)

    # Save normalizer.
    # Note: if the normalizer has already been saved, save() is a no-op. To
    # regenerate the normalizer, delete the normalizer file in
    # train_root_dir/assets.
    normalizer.save(real_images)
  else:
    while not normalizer.exists():
      time.sleep(5)
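

# run() trains the GAN one stage at a time: each stage rebuilds the graph,
# looks up that stage's batch size, constructs the model, and hands it to
# train_util.train(). When 'train_progressive' is disabled, only the last
# stage in the schedule is trained.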
def run(config):
  """Entry point to run training."""
  init_data_normalizer(config)

  stage_ids = train_util.get_stage_ids(**config)
  if not config['train_progressive']:
    stage_ids = list(stage_ids)[-1:]

  # Train one stage at a time.
  for stage_id in stage_ids:
    batch_size = train_util.get_batch_size(stage_id, **config)
    tf.reset_default_graph()
    with tf.device(tf.train.replica_device_setter(config['ps_tasks'])):
      model = lib_model.Model(stage_id, batch_size, config)
      model.add_summaries(stage_id)
      print('Variables:')
      for v in tf.global_variables():
        print('\t', v.name, v.get_shape().as_list())
      logging.info('Calling train.train')
      train_util.train(model, **config)
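

# main() assembles the training configuration in order: hparams from the
# --config module (if given), then overrides from the --hparams JSON string,
# then library defaults via lib_model.set_flags(). The merged flags are
# written to experiment.json under train_root_dir so the run can be reloaded
# later.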
def main(unused_argv):
  absl.flags.FLAGS.alsologtostderr = True

  # Set hyperparams from JSON args and defaults.
  flags = lib_flags.Flags()

  # Config hparams.
  if FLAGS.config:
    config_module = importlib.import_module(
        'magenta.models.gansynth.configs.{}'.format(FLAGS.config))
    flags.load(config_module.hparams)

  # Command-line hparams.
  flags.load_json(FLAGS.hparams)

  # Set default flags.
  lib_model.set_flags(flags)

  print('Flags:')
  flags.print_values()

  # Create the training directory.
  flags['train_root_dir'] = util.expand_path(flags['train_root_dir'])
  if not tf.gfile.Exists(flags['train_root_dir']):
    tf.gfile.MakeDirs(flags['train_root_dir'])

  # Save the flags to help with loading the model later.
  fname = os.path.join(flags['train_root_dir'], 'experiment.json')
  with tf.gfile.Open(fname, 'w') as f:
    json.dump(flags, f)  # pytype: disable=wrong-arg-types

  # Run training.
  run(flags)


def console_entry_point():
  tf.app.run(main)


if __name__ == '__main__':
  console_entry_point()