# -*- coding:utf-8 -*-
# Copyright xmuspeech (Author: Snowdar 2020-02-06)
# Apache 2.0
# This script only supports single-GPU training and it is a simple example of a standard x-vector.
# For more, see runSnowdarXvector.py and runResnetXvector.py.
import sys, os
import logging
import argparse
import traceback
import time
import math
import numpy as np
import torch
sys.path.insert(0, 'subtools/pytorch')
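# Note: this relative path (like the data/ and exp/ paths below) assumes the launcher is executed from the
# project root directory, i.e. the directory that contains subtools/.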
import libs.egs.egs as egs
import libs.training.optim as optim
import libs.training.lr_scheduler as learn_rate_scheduler
import libs.training.trainer as trainer
import libs.support.kaldi_common as kaldi_common
import libs.support.utils as utils
# Logger
logger = logging.getLogger('libs')
logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
handler.setLevel(logging.INFO)
formatter = logging.Formatter("%(asctime)s [ %(pathname)s:%(lineno)s - "
                              "%(funcName)s - %(levelname)s ]\n#### %(message)s")
handler.setFormatter(formatter)
logger.addHandler(handler)
# Parser: add this parser to run the launcher with some frequently used options (really just for convenience).
parser = argparse.ArgumentParser(
    description="""Train an x-vector framework with pytorch.""",
    formatter_class=argparse.RawTextHelpFormatter,
    conflict_handler='resolve')

parser.add_argument("--stage", type=int, default=3,
                    help="The stage to control the start point of this pipeline (default 3).\n"
                         "    stage 0: vad-cmn (preprocess_to_egs.sh).\n"
                         "    stage 1: remove utts (preprocess_to_egs.sh).\n"
                         "    stage 2: get chunk egs (preprocess_to_egs.sh).\n"
                         "    stage 3: training.\n"
                         "    stage 4: extract xvector.")

parser.add_argument("--endstage", type=int, default=4,
                    help="The endstage to control the end point of this pipeline (default 4).")

parser.add_argument("--train-stage", type=int, default=-1,
                    help="The stage to control the start of the training epoch (default -1).\n"
                         "    -1 -> creating model_dir.\n"
                         "     0 -> model initialization (e.g. transfer learning).\n"
                         "    >0 -> recovering training.")

parser.add_argument("--force-clear", type=str, action=kaldi_common.StrToBoolAction,
                    default=False, choices=["true", "false"],
                    help="Clear the dirs generated by preprocessing.")

parser.add_argument("--use-gpu", type=str, action=kaldi_common.StrToBoolAction,
                    default=True, choices=["true", "false"],
                    help="Use GPU or not.")

parser.add_argument("--gpu-id", type=str, default="",
                    help="If empty, the GPU will be auto-selected.")

parser.add_argument("--benchmark", type=str, action=kaldi_common.StrToBoolAction,
                    default=True, choices=["true", "false"],
                    help="If true, save training time but require a little more GPU memory.")

parser.add_argument("--run-lr-finder", type=str, action=kaldi_common.StrToBoolAction,
                    default=False, choices=["true", "false"],
                    help="If true, run the lr finder rather than training.")

args = parser.parse_args()
##--------------------------------------------------##
## Control options
stage = max(0, args.stage)
endstage = min(4, args.endstage)
train_stage = max(-1, args.train_stage)
##--------------------------------------------------##
## Preprocess options
force_clear = args.force_clear
preprocess_nj = 20
compress = False
cmn = True # traditional cmn process
chunk_size = 200
limit_utts = 8
sample_type = "speaker_balance" # sequential | speaker_balance
chunk_num = 0 # -1 means using scale, 0 means using the max, and >0 means using the given value itself.
overlap = 0.1
scale = 1.5 # Get max / num_spks * scale for every speaker.
valid_split_type = "--total-spk" # --total-spk or --default
valid_utts = 1024
valid_chunk_num_every_utt = 2
##--------------------------------------------------##
## Training options
use_gpu = args.use_gpu # Default true.
benchmark = args.benchmark # If true, saves training time but requires a little more GPU memory.
gpu_id = args.gpu_id # If empty, the GPU will be auto-selected.
run_lr_finder = args.run_lr_finder
egs_params = {
    "aug":None, # None or specaugment
    "aug_params":{"frequency":0.2, "frame":0.2}
}
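# Assumption based on the parameter names: when "aug" is set to "specaugment", the ratios in "aug_params"
# presumably control how large a fraction of the frequency bins and of the frames is masked per chunk;
# "aug":None disables augmentation entirely.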
loader_params = {
    "use_fast_loader":True, # A queue loader that prefetches batches and keeps them in storage.
    "max_prefetch":10,
    "batch_size":512,
    "shuffle":True,
    "num_workers":2,
    "pin_memory":False,
    "drop_last":True,
}
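# Assumption based on the names: "use_fast_loader" presumably wraps the normal DataLoader with a background
# queue that keeps up to "max_prefetch" batches ready, so feature loading overlaps with GPU computation.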
# Define model_params by model_blueprint w.r.t your model's __init__(model_params).
model_params = {
    "bn_momentum":0.99,
    "nonlinearity":"relu",
    "aug_dropout":0.2, # Should not be too large.
    "training":True,
    "extracted_embedding":"far" # For extracting. Here it is far or near w.r.t xvector.py.
}
optimizer_params = {
    "name":"ralamb",
    "learn_rate":0.001,
    "beta1":0.9,
    "beta2":0.999,
    "beta3":0.999,
    "weight_decay":1e-1, # Should be large for decoupled weight decay (adamW) and small for L2 regularization (sgd, adam).
    "lookahead.k":5,
    "lookahead.alpha":0. # 0 means not using lookahead; if used, 0.5 is suggested.
}
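# Note on the lookahead.* keys (assuming they configure the standard Lookahead wrapper of Zhang et al., 2019):
# every lookahead.k fast steps the slow weights would be updated as slow += alpha * (fast - slow), so alpha=0
# keeps the slow weights frozen, which effectively disables lookahead as the comment above states.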
lr_scheduler_params = {
    "name":"warmR",
    "warmR.lr_decay_step":400, # 0 means decay after every epoch and 1 means every iter.
    "warmR.T_max":6,
    "warmR.T_mult":1,
    "warmR.factor":0.7, # The max_lr_decay_factor.
    "warmR.eta_min":4e-8,
    "warmR.log_decay":False
}
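# Sketch of the resulting schedule (assuming warmR is the usual cosine annealing with warm restarts, as the
# parameter names suggest): within a cycle the lr would follow
#   lr = eta_min + (max_lr - eta_min) * (1 + cos(pi * T_cur / T_max)) / 2,
# the cycle restarts every T_max=6 epochs (T_mult=1 keeps the cycle length constant), and factor=0.7 scales the
# peak lr down at each restart; that is why epochs=18 below corresponds to three complete cycles (6 -> 12 -> 18).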
epochs = 18 # Total epochs to train. It is important. Here 18 = 6 -> 12 -> 18 with warmR.T_mult=1 and warmR.T_max=6.
report_times_every_epoch = None
report_interval_iters = 100 # Controls validation computation and loss reporting. If report_times_every_epoch is not None,
                            # report_interval_iters is computed from report_times_every_epoch instead.
stop_early = False
suffix = "params" # Used in saved model file.
##--------------------------------------------------##
## Other options
exist_model = "" # Use it in transfer learning.
##--------------------------------------------------##
## Main params
traindata = "data/mfcc_23_pitch/voxceleb1_train_aug"
egs_dir = "exp/egs/mfcc_23_pitch_voxceleb1_train_aug" + "_" + sample_type + "_max"
model_blueprint = "subtools/pytorch/model/xvector.py"
model_dir = "exp/standard_xv_baseline_warmR_voxceleb1"
##--------------------------------------------------##
##
#### Set seed
utils.set_all_seed(1024) # Note that, on different machines, the random numbers will still differ even with the same seed,
                         # so you may always get slightly different results from this launcher compared to mine.
#### Preprocess
if stage <= 2 and endstage >= 0:
    # Only a limited set of options is exposed here because passing everything through is not convenient.
    # It is suggested to pre-execute this shell script with full freedom and then continue from a later stage of this launcher.
    kaldi_common.execute_command("bash subtools/pytorch/pipeline/preprocess_to_egs.sh "
                                 "--stage {stage} --endstage {endstage} --valid-split-type {valid_split_type} "
                                 "--nj {nj} --cmn {cmn} --limit-utts {limit_utts} --min-chunk {chunk_size} --overlap {overlap} "
                                 "--sample-type {sample_type} --chunk-num {chunk_num} --scale {scale} --force-clear {force_clear} "
                                 "--valid-num-utts {valid_utts} --valid-chunk-num {valid_chunk_num_every_utt} --compress {compress} "
                                 "{traindata} {egs_dir}".format(stage=stage, endstage=endstage, valid_split_type=valid_split_type,
                                 nj=preprocess_nj, cmn=str(cmn).lower(), limit_utts=limit_utts, chunk_size=chunk_size, overlap=overlap,
                                 sample_type=sample_type, chunk_num=chunk_num, scale=scale, force_clear=str(force_clear).lower(),
                                 valid_utts=valid_utts, valid_chunk_num_every_utt=valid_chunk_num_every_utt, compress=str(compress).lower(),
                                 traindata=traindata, egs_dir=egs_dir))
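# For reference (a rough expansion, not run by default since --stage defaults to 3): running this launcher with
# --stage 0 --endstage 2 and the options above would format the command to roughly
#   bash subtools/pytorch/pipeline/preprocess_to_egs.sh --stage 0 --endstage 2 --valid-split-type --total-spk \
#        --nj 20 --cmn true --limit-utts 8 --min-chunk 200 --overlap 0.1 --sample-type speaker_balance \
#        --chunk-num 0 --scale 1.5 --force-clear false --valid-num-utts 1024 --valid-chunk-num 2 --compress false \
#        data/mfcc_23_pitch/voxceleb1_train_aug exp/egs/mfcc_23_pitch_voxceleb1_train_aug_speaker_balance_max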
#### Train model
if stage <= 3 <= endstage:
    logger.info("Get model_blueprint from model directory.")
    # Save the raw model_blueprint in model_dir/config and get a copy of the model_blueprint path.
    model_blueprint = utils.create_model_dir(model_dir, model_blueprint, stage=train_stage)

    logger.info("Load egs to bunch.")
    # The dict [info] contains feat_dim and num_targets.
    bunch, info = egs.BaseBunch.get_bunch_from_egsdir(egs_dir, egs_params, loader_params)

    logger.info("Create model from model blueprint.")
    # Another way: import model.py directly in this python script, but that is not friendly to the extracting
    # shell script, and I do not want to change the extracting script whenever model.py changes.
    model_py = utils.create_model_from_py(model_blueprint)
    # Give your model class name here w.r.t the model.py.
    model = model_py.Xvector(info["feat_dim"], info["num_targets"], **model_params)

    logger.info("Define optimizer and lr_scheduler.")
    optimizer = optim.get_optimizer(model, optimizer_params)
    lr_scheduler = learn_rate_scheduler.LRSchedulerWrapper(optimizer, lr_scheduler_params)

    # Record params to model_dir
    utils.write_list_to_file([egs_params, loader_params, model_params, optimizer_params,
                              lr_scheduler_params], model_dir+'/config/params.dict')

    logger.info("Init a simple trainer.")
    # Package = (Elements:dict, Params:dict). It is the key parameter package for the trainer and model_dir/config/.
    package = ({"data":bunch, "model":model, "optimizer":optimizer, "lr_scheduler":lr_scheduler},
               {"model_dir":model_dir, "model_blueprint":model_blueprint, "exist_model":exist_model,
                "start_epoch":train_stage, "epochs":epochs, "use_gpu":use_gpu, "gpu_id":gpu_id,
                "benchmark":benchmark, "suffix":suffix, "report_times_every_epoch":report_times_every_epoch,
                "report_interval_iters":report_interval_iters, "record_file":"train.csv"})

    trainer = trainer.SimpleTrainer(package)

    if run_lr_finder:
        trainer.run_lr_finder("lr_finder.csv", init_lr=1e-8, final_lr=10., num_iters=2000, beta=0.98)
        endstage = 3 # Do not start the extractor.
    else:
        trainer.run()

    # Plan: use del to free memory after training is done and then continue to stage 4.
    # But it does not work yet and is still an open problem.
    # For now, use runLauncher.sh to avoid this problem.
    # del bunch, model, optimizer, lr_scheduler, trainer
#### Extract xvector
if stage <= 4 <= endstage:
    # There are some params for xvector extracting.
    data_root = "data" # It contains all datasets, just like a Kaldi recipe.
    prefix = "mfcc_23_pitch" # For to_extracted_data.
    to_extracted_positions = ["far", "near"] # Define this w.r.t model_blueprint.
    to_extracted_data = ["voxceleb1_train_aug", "voxceleb1_test"] # All datasets should be in data_root/prefix.
    to_extracted_epochs = ["6","12","18"] # These are model names, such as 10.params or final.params (the suffix is w.r.t the package).

    nj = 10
    force = False
    use_gpu = True
    gpu_id = ""
    sleep_time = 10

    # Run a batch extracting process.
    try:
        for position in to_extracted_positions:
            # Generate the extracting config from the nnet config, where the position to extract depends on
            # the 'extracted_embedding' parameter of model_creation (by my design).
            model_blueprint, model_creation = utils.read_nnet_config("{0}/config/nnet.config".format(model_dir))
            model_creation = model_creation.replace("training=True", "training=False") # To save memory without loading some independent components.
            model_creation = model_creation.replace(model_params["extracted_embedding"], position)
            extract_config = "{0}.extract.config".format(position)
            utils.write_nnet_config(model_blueprint, model_creation, "{0}/config/{1}".format(model_dir, extract_config))

            for epoch in to_extracted_epochs:
                model_file = "{0}.{1}".format(epoch, suffix)
                point_name = "{0}_epoch_{1}".format(position, epoch)

                # If the trainer runs with a background thread (not supported yet) or this launcher is run separately
                # with stage=4 (i.e., in another process), this waiting loop is useful to start extracting as soon as
                # a model file appears (but it requires more GPU memory).
                model_path = "{0}/{1}".format(model_dir, model_file)
                while True:
                    if os.path.exists(model_path):
                        break
                    else:
                        time.sleep(sleep_time)

                for data in to_extracted_data:
                    datadir = "{0}/{1}/{2}".format(data_root, prefix, data)
                    outdir = "{0}/{1}/{2}".format(model_dir, point_name, data)
                    # Use a well-optimized shell script (with multiple processes) to extract xvectors.
                    # Another way: use subtools/splitDataByLength.sh and subtools/pytorch/pipeline/onestep/extract_embeddings.py
                    # with python threads to extract xvectors directly, but the shell script is more convenient.
                    kaldi_common.execute_command("bash subtools/pytorch/pipeline/extract_xvectors_for_pytorch.sh "
                                                 "--model {model_file} --cmn {cmn} --nj {nj} --use-gpu {use_gpu} --gpu-id '{gpu_id}' "
                                                 "--force {force} --nnet-config config/{extract_config} "
                                                 "{model_dir} {datadir} {outdir}".format(model_file=model_file, cmn=str(cmn).lower(), nj=nj,
                                                 use_gpu=str(use_gpu).lower(), gpu_id=gpu_id, force=str(force).lower(), extract_config=extract_config,
                                                 model_dir=model_dir, datadir=datadir, outdir=outdir))
    except BaseException as e:
        if not isinstance(e, KeyboardInterrupt):
            traceback.print_exc()
        sys.exit(1)
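# Usage sketch (the exact invocation may differ in your setup): once training has finished, the extraction stage
# can be re-run alone in a separate process, e.g.
#   python3 runXvector.py --stage 4 --endstage 4 --gpu-id 0
# relying on the --stage/--endstage options defined at the top of this launcher.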