-
Notifications
You must be signed in to change notification settings - Fork 74
/
run.py
119 lines (100 loc) · 4.72 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import datetime
import os
import shutil
import subprocess
import sys

sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), '..'))
from config import FLAGS
def save_config_file(directory):
    """Save a timestamped copy of config.py for a specific run.

    The file `config.py` is read from FLAGS['t2t_usr_dir'] and copied to
    `<directory>/config.<year>.<month>.<day>.<hour>.<minute>.<second>.txt`.
    The timestamp fields come from the local clock and are not zero-padded.

    Args:
        directory: Folder that receives the copy. It is created (including
            missing parents) if it does not exist yet.
    """
    # Make the target dir if it doesn't exist.
    if not os.path.exists(directory):
        os.makedirs(directory)

    # This will be used in the name of the saved file.
    now = datetime.datetime.now()
    time_string = '.'.join(
        str(field) for field in (now.year, now.month, now.day,
                                 now.hour, now.minute, now.second))

    source = FLAGS['t2t_usr_dir'] + '/config.py'
    target = directory + '/config.' + time_string + '.txt'

    # Copy in-process rather than through a `cp` shell command, so that paths
    # containing spaces or shell metacharacters are handled correctly.
    try:
        shutil.copy(source, target)
    except (IOError, OSError) as error:
        # The previous `cp` call only reported failures on stderr and let the
        # run continue; keep that best-effort behaviour.
        print('Could not save config file: ' + str(error))
def data_generating():
    """Initialize a data generation problem.

    Saves a copy of config.py into FLAGS['data_dir'], then runs the external
    `t2t-datagen` command with the user dir, data dir and problem taken from
    FLAGS, and waits for it to finish. The command's exit status is ignored.

    Raises:
        OSError: If the `t2t-datagen` executable cannot be started.
    """
    print('Program is running in data generation mode.')
    save_config_file(FLAGS['data_dir'])

    # Pass an argument list (no shell) so that values containing spaces or
    # shell metacharacters reach t2t-datagen unchanged. Note that values are
    # therefore taken literally: no shell expansion is applied to them.
    subprocess.call([
        't2t-datagen',
        '--t2t_usr_dir=' + FLAGS['t2t_usr_dir'],
        '--data_dir=' + FLAGS['data_dir'],
        '--problem=' + FLAGS['problem'],
    ])
def training():
    """Initialize a training loop with the given flags.

    Saves a copy of config.py into FLAGS['train_dir'], then runs the external
    `t2t-trainer` command configured from FLAGS and waits for it to finish.
    The command's exit status is ignored.

    Raises:
        OSError: If the `t2t-trainer` executable cannot be started.
    """
    print('Program is running in training mode.')
    save_config_file(FLAGS['train_dir'])

    # What hparams should we use: an empty FLAGS['hparams'] selects the
    # per-model default set named 'general_<model>_hparams'.
    if FLAGS['hparams'] == '':
        hparam_string = 'general_' + FLAGS['model'] + '_hparams'
    else:
        hparam_string = FLAGS['hparams']

    # Pass an argument list (no shell) so that values containing spaces or
    # shell metacharacters reach t2t-trainer unchanged. Note that values are
    # therefore taken literally: no shell expansion is applied to them.
    subprocess.call([
        't2t-trainer',
        '--generate_data=False',
        '--t2t_usr_dir=' + FLAGS['t2t_usr_dir'],
        '--data_dir=' + FLAGS['data_dir'],
        '--problem=' + FLAGS['problem'],
        '--output_dir=' + FLAGS['train_dir'],
        '--model=' + FLAGS['model'],
        '--hparams_set=' + hparam_string,
        '--schedule=' + FLAGS['train_mode'],
        '--worker_gpu_memory_fraction=' + str(FLAGS['memory_fraction']),
        '--keep_checkpoint_max=' + str(FLAGS['keep_checkpoints']),
        '--keep_checkpoint_every_n_hours=' + str(FLAGS['save_every_n_hour']),
        '--save_checkpoints_secs=' + str(FLAGS['save_every_n_secs']),
        '--train_steps=' + str(FLAGS['train_steps']),
        '--eval_steps=' + str(FLAGS['evaluation_steps']),
        '--local_eval_frequency=' + str(FLAGS['evaluation_freq']),
    ])
def decoding():
    """Initialize an inference test with the given flags.

    Saves a copy of config.py into FLAGS['decode_dir'], then runs the external
    `t2t-decoder` command configured from FLAGS and waits for it to finish.
    The command's exit status is ignored.

    FLAGS['decode_mode'] selects the input source: 'interactive' adds
    `--decode_interactive`, 'file' adds `--decode_from_file` pointing at
    FLAGS['input_file_name'] inside FLAGS['decode_dir']; any other value adds
    neither flag.

    Raises:
        OSError: If the `t2t-decoder` executable cannot be started.
    """
    print('Program is running in inference/decoding mode.')
    save_config_file(FLAGS['decode_dir'])

    # What hparams should we use: an empty FLAGS['hparams'] selects the
    # per-model default set named 'general_<model>_hparams'.
    if FLAGS['hparams'] == '':
        hparam_string = 'general_' + FLAGS['model'] + '_hparams'
    else:
        hparam_string = FLAGS['hparams']

    # str() on every value: return_beams used to be concatenated raw, which
    # raised a TypeError whenever it was configured as a bool.
    decode_hparams = ','.join([
        'beam_size=' + str(FLAGS['beam_size']),
        'return_beams=' + str(FLAGS['return_beams']),
        'batch_size=' + str(FLAGS['batch_size']),
    ])

    # Pass an argument list (no shell) so that values containing spaces or
    # shell metacharacters reach t2t-decoder unchanged. Note that values are
    # therefore taken literally: no shell expansion is applied to them.
    command = [
        't2t-decoder',
        '--generate_data=False',
        '--t2t_usr_dir=' + FLAGS['t2t_usr_dir'],
        '--data_dir=' + FLAGS['data_dir'],
        '--problem=' + FLAGS['problem'],
        '--output_dir=' + FLAGS['train_dir'],
        '--model=' + FLAGS['model'],
        '--worker_gpu_memory_fraction=' + str(FLAGS['memory_fraction']),
        '--hparams_set=' + hparam_string,
        '--decode_to_file=' +
        FLAGS['decode_dir'] + '/' + FLAGS['output_file_name'],
        '--decode_hparams=' + decode_hparams,
    ]

    # Determine the decode mode flag.
    if FLAGS['decode_mode'] == 'interactive':
        command.append('--decode_interactive')
    elif FLAGS['decode_mode'] == 'file':
        command.append('--decode_from_file=' +
                       FLAGS['decode_dir'] + '/' + FLAGS['input_file_name'])

    subprocess.call(command)
def experiment(ckpt_list=None):
    """Run a longer experiment: decode the test set once per checkpoint.

    For every checkpoint step, the `checkpoint` file inside FLAGS['train_dir']
    is overwritten so that it names `model.ckpt-<step>`,
    FLAGS['output_file_name'] is set to `test_set_<step>.txt`, and decoding()
    is called.

    Both the `checkpoint` file and FLAGS['output_file_name'] are left at the
    values of the last step when the function returns.

    NOTE: to run this against a different data/train/decode folder, set
    FLAGS['data_dir'], FLAGS['train_dir'] and FLAGS['decode_dir'] before
    calling; the per-folder overrides that used to be commented out here
    referenced an undefined variable and were removed.

    Args:
        ckpt_list: Optional iterable of checkpoint step numbers. When omitted,
            the steps hard-coded for the original experiment are used.
    """
    if ckpt_list is None:
        ckpt_list = [1, 1328, 2647, 3963, 5284, 6611, 7932, 9254,
                     10581, 11902, 13227, 14558, 15882, 17209]

    for ckpt in ckpt_list:
        # Overwrite the checkpoint index file so that it names this step
        # before decoding() is run.
        with open(FLAGS["train_dir"] + "/checkpoint", "w") as ckpt_file:
            ckpt_file.write(
                'model_checkpoint_path: "model.ckpt-' + str(ckpt) + '"')

        FLAGS["output_file_name"] = "test_set_" + str(ckpt) + ".txt"
        decoding()