main.py
"""
Training main function for DECAF.
Created by Renhao Liu, CIG, WUSTL, 2021.
"""
import os

# Configure GPU visibility and TensorFlow logging before TensorFlow is imported.
os.environ["CUDA_VISIBLE_DEVICES"] = "0, 1"  # expose GPUs 0 and 1
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"  # silence TF C++ log messages
os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true"  # allocate GPU memory on demand

import logging
import warnings

import tensorflow as tf

tf.get_logger().setLevel(logging.ERROR)
warnings.filterwarnings("ignore")
import h5py
from absl import app
from absl import flags

from model.model import Model
from model.provider import DecafEndToEndProvider

FLAGS = flags.FLAGS
# Version parameters
flags.DEFINE_string("name", "DECAF", "model name")
# Path parameters; replace the defaults with your own paths.
flags.DEFINE_string("input_dir", "", "path to the input HDF5 file")
flags.DEFINE_string("model_save_dir", "saved_model", "directory for saving the model")
# Training hyperparameters
flags.DEFINE_integer("epochs", 40000, "number of training epochs")
flags.DEFINE_float("start_lr", 8e-4, "initial learning rate")
flags.DEFINE_float("end_lr", 3e-4, "final learning rate")
def main(argv):
    """DECAF main function."""
    # Build a per-epoch learning-rate schedule that decays geometrically
    # from start_lr to end_lr over the course of training.
    lr_start = FLAGS.start_lr
    lr_end = FLAGS.end_lr
    epochs = FLAGS.epochs
    multiplier = (lr_end / lr_start) ** (1 / epochs)
    decayed_lr = [lr_start * (multiplier ** x) for x in range(epochs)]
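    # Worked example with the default flags: multiplier = (3e-4 / 8e-4) ** (1 / 40000),
    # roughly 0.9999755, so decayed_lr[0] == start_lr and the last entry,
    # start_lr * multiplier ** (epochs - 1), sits one multiplication step above end_lr.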
    train_kargs = {
        "epochs": epochs,
        "learning_rate": decayed_lr,
    }
    # Save the full flag configuration next to the model so the run can be reproduced.
    directory = FLAGS.model_save_dir
    if not os.path.exists(directory):
        os.makedirs(directory)
    with open(os.path.join(directory, "config.txt"), "w") as config_file:
        config_file.write(FLAGS.flags_into_string())
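    # Note: flags_into_string() emits one --flag=value per line, so the saved
    # config.txt can be replayed with absl's built-in --flagfile flag, e.g.
    # `python main.py --flagfile=saved_model/config.txt`.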
    # Open the training data (HDF5) and run end-to-end training.
    h5_file = h5py.File(FLAGS.input_dir, "r")
    train_provider = DecafEndToEndProvider(h5_file)
    model = Model(name=FLAGS.name)
    model.train(FLAGS.model_save_dir, train_provider, **train_kargs)
if __name__ == "__main__":
    app.run(main)
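# Example invocation (the input path below is illustrative, not from the repo):
#
#   python main.py --input_dir=data/train.h5 --model_save_dir=saved_model \
#       --epochs=40000 --start_lr=8e-4 --end_lr=3e-4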