#!/usr/bin/python
import os, sys
import time
import argparse
import numpy as np
import logging
parser = argparse.ArgumentParser()
parser.add_argument('-c', '--config-file', dest='config_file', required=True, type=str, help='Input configuration file')
parser.add_argument('-o', '--output-directory', dest='output_dir', required=True, type=str, help='Path to output directory')
parser.add_argument('-d', '--device', dest='device', default='cpu', help='Theano device')
parser.add_argument('--overwrite', dest='overwrite', action='store_true', help='Enable this flag for overwriting the same model file after every epoch')
parser.add_argument('-v', '--verbose', dest='verbose', action='store_true', help='Enable this flag for debugging')
args = parser.parse_args()
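
# Example invocation (illustrative only; the config filename and output path are
# placeholders, not files shipped with the repository):
#   ./train -c experiment.ini -o models/experiment1 -d cuda0 --overwrite
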
# TODO: implement training parameters here
# create the output directory if it does not exist
if not os.path.exists(args.output_dir):
    os.makedirs(args.output_dir)
else:
    # TODO: prompt before overwriting an existing output directory
    pass

# setting up the logger
from deeptxt.utils import logutils
logger = logging.getLogger(__name__)
logutils.set_logger(args.output_dir)
logger.info("Importing backend: Theano")
from deeptxt.utils import theano_config
if args.device == 'cpu':
    logger.warning("Using CPU. To use GPU, use '--device cuda<device-number>'")
theano_config.configure(device=args.device, dtype='float32', verbose=args.verbose)
from deeptxt.trainer import Trainer
from deeptxt.evaluator import Evaluator
from deeptxt.io.vocab_manager import VocabManager
from deeptxt.utils.prepare_vocabulary import write_vocab
if args.config_file is not None:
    if args.config_file.split('.')[-1] == 'ini':
        from deeptxt.config.config_parser import IniParser
        config = IniParser(args.config_file)
    elif args.config_file.split('.')[-1] == 'yml':
        raise NotImplementedError
    else:
        logger.error('Config file should have .ini extension and INI format')
        sys.exit()
else:
    logger.error('Missing config file!')
    sys.exit()
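
# Illustrative sketch of the training configuration this script reads, assembled
# from the attributes referenced below (section/key names and all values are
# assumptions; consult deeptxt.config.config_parser.IniParser for the exact format):
#   model = rnnlm                 ; or seq2seq
#   [train_params]
#   train_data = /path/to/train.txt
#   validation_data = /path/to/valid.txt
#   batchsize = 64
#   validation_batchsize = 64
#   max_length = 50
#   optimizer = adam              ; sgd | adagrad | rmsprop | adadelta | adam
#   learning_rate = 0.001
#   num_epochs = 10
#   display_freq = 100
#   validation_freq = 1000
#   save_freq = 1000
#   sampling_freq = 100
#   [network_params]
#   vocab_size = 30000
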
##############################################
########### MODEL INITIALIZATION #############
##############################################
logger.info('Initializing model')

## RNN Language Model
if config.model == 'rnnlm':
    from deeptxt.models.rnn_language_model import RNNLanguageModel
    from deeptxt.io.readers.text_reader import TextReader
    train_data = TextReader(dataset_path=config.train_params.train_data, batchsize=config.train_params.batchsize, max_length=config.train_params.max_length)
    validation_data = TextReader(dataset_path=config.train_params.validation_data, batchsize=config.train_params.validation_batchsize, max_length=config.train_params.max_length)
    if not hasattr(config.train_params, 'vocab_path'):
        setattr(config.train_params, 'vocab_path', args.output_dir + '/vocab.txt')
        write_vocab(config.train_params.train_data, config.train_params.vocab_path)
    vocab = VocabManager(config.train_params.vocab_path, config.network_params.vocab_size)
    model = RNNLanguageModel(hyperparams=config.network_params, vocab=vocab)
# Seq2Seq Models
elif config.model == 'seq2seq':
    from deeptxt.models.rnn_encoder_decoder import RNNEncoderDecoder as EncDec
    from deeptxt.io.readers.parallel_text_reader import ParallelTextReader
    train_data = ParallelTextReader(source_dataset_path=config.train_params.source_train_data,
                                    target_dataset_path=config.train_params.target_train_data,
                                    batchsize=config.train_params.batchsize,
                                    source_max_length=config.train_params.max_length,
                                    target_max_length=config.train_params.max_length)
    validation_data = ParallelTextReader(source_dataset_path=config.train_params.source_validation_data,
                                         target_dataset_path=config.train_params.target_validation_data,
                                         batchsize=config.train_params.validation_batchsize,
                                         source_max_length=config.train_params.max_length,
                                         target_max_length=config.train_params.max_length)
    if not hasattr(config.train_params, 'source_vocab_path'):
        setattr(config.train_params, 'source_vocab_path', args.output_dir + '/source_vocab.txt')
        write_vocab(config.train_params.source_train_data, config.train_params.source_vocab_path)
    if not hasattr(config.train_params, 'target_vocab_path'):
        setattr(config.train_params, 'target_vocab_path', args.output_dir + '/target_vocab.txt')
        write_vocab(config.train_params.target_train_data, config.train_params.target_vocab_path)
    source_vocab = VocabManager(config.train_params.source_vocab_path, config.network_params.encoder_vocab_size)
    target_vocab = VocabManager(config.train_params.target_vocab_path, config.network_params.decoder_vocab_size)
    model = EncDec(hyperparams=config.network_params, encoder_vocab=source_vocab, decoder_vocab=target_vocab,
                   encode_bidirectional=config.network_params.encode_bidirectional,
                   use_attention=config.network_params.use_attention)
else:
    logger.error("Invalid model name: " + config.model)
    logger.error('Ensure that strings in the config file are not wrapped in quotation marks')
    sys.exit()

############
## Training
############
num_samples = 0
### Setting the optimizer
if config.train_params.optimizer == 'sgd':
    from deeptxt.optimizers.sgd import SGD as Optimizer
elif config.train_params.optimizer == 'adagrad':
    from deeptxt.optimizers.adagrad import Adagrad as Optimizer
elif config.train_params.optimizer == 'rmsprop':
    from deeptxt.optimizers.rmsprop import RMSProp as Optimizer
elif config.train_params.optimizer == 'adadelta':
    from deeptxt.optimizers.adadelta import Adadelta as Optimizer
elif config.train_params.optimizer == 'adam':
    from deeptxt.optimizers.adam import Adam as Optimizer
else:
    raise NotImplementedError

### Building model
logger.info('Building model graph')
model.build()
### Preparing trainer and validator objects
logger.info('Building trainer')
kwargs_optimizer = dict()
if hasattr(config.train_params, 'clip_norm'):
    kwargs_optimizer['clip_norm'] = config.train_params.clip_norm
if hasattr(config.train_params, 'clip_value'):
    kwargs_optimizer['clip_value'] = config.train_params.clip_value
optimizer = Optimizer(learning_rate=config.train_params.learning_rate, **kwargs_optimizer)
trainer = Trainer(model=model, optimizer=optimizer, cache_minibatches=True)
logger.info('Building validator')
validator = Evaluator(data_reader=validation_data, model=model)
if hasattr(config.train_params, 'init_model'):
    logger.info('Initializing parameters from model: ' + config.train_params.init_model)
    if os.path.exists(config.train_params.init_model):
        trainer.load_model(config.train_params.init_model)
    else:
        logger.error('Invalid model path: %s' % config.train_params.init_model)

for epoch in xrange(config.train_params.num_epochs):
    train_loss_sum = 0
    train_batches = 0
    while True:
        # Update the parameters after computing the loss
        train_loss = trainer.update(data_reader=train_data, verbose=args.verbose, minibatch_index=train_batches)
        # end of epoch is reached
        if train_loss is None:
            break
        # Check if the loss is NaN or infinity
        if np.isnan(train_loss) or np.isinf(train_loss):
            logger.error("NaN detected at update: " + str(trainer.num_updates))
        # Incrementing updates and adding loss
        train_loss_sum += train_loss
        train_batches += 1
        # displaying
        if config.train_params.display_freq > 0 and trainer.num_updates % config.train_params.display_freq == 0:
            trainer.log_train_info(epoch=epoch, loss=train_loss_sum, num_batches=train_batches)
        # validation
        if config.train_params.validation_freq > 0 and trainer.num_updates % config.train_params.validation_freq == 0:
            validator.log_validation_info()
        # saving
        if config.train_params.save_freq > 0 and trainer.num_updates % config.train_params.save_freq == 0:
            model_path = args.output_dir + '/model_iter' + str(trainer.num_updates) + '.npz'
            logger.info("Saving model: " + model_path)
            trainer.save_model(model_path, validator)
        # sampling
        if config.train_params.sampling_freq > 0 and trainer.num_updates % config.train_params.sampling_freq == 0:
            trainer.sampling(num_samples=1)
    # displaying
    trainer.log_train_info(epoch=epoch, loss=train_loss_sum, num_batches=train_batches)
    # validation
    validator.log_validation_info()
    logger.info("Epoch %d completed. Saving." % (epoch))
    # saving
    if args.overwrite:
        trainer.save_model(args.output_dir + '/model.npz', validator)
    else:
        trainer.save_model(args.output_dir + '/model_epoch' + str(epoch) + '.npz', validator)