# ################################
# Model: wav2vec2 + DNN + CTC
# Augmentation: SpecAugment
# Authors: Pooneh Mousavi 2023
# ################################
# Seed needs to be set at top of yaml, before objects with parameters are made
seed: 1234
__set_seed: !!python/object/apply:torch.manual_seed [!ref <seed>]
output_folder: !ref results/wav2vec2_ctc_es/<seed>
test_wer_file: !ref <output_folder>/wer_test.txt
save_folder: !ref <output_folder>/save
train_log: !ref <output_folder>/train_log.txt
# Hub ID of the biggest Fairseq multilingual model (XLSR-53), hosted on HuggingFace
wav2vec2_hub: facebook/wav2vec2-large-xlsr-53
wav2vec2_folder: !ref <save_folder>/wav2vec2_checkpoint
# Data files
data_folder: !PLACEHOLDER # e.g., /localscratch/cv-corpus-5.1-2020-06-22/es
train_tsv_file: !ref <data_folder>/train.tsv # Standard CommonVoice .tsv files
dev_tsv_file: !ref <data_folder>/dev.tsv # Standard CommonVoice .tsv files
test_tsv_file: !ref <data_folder>/test.tsv # Standard CommonVoice .tsv files
accented_letters: True
language: es # use 'it' for Italian, 'rw' for Kinyarwanda, 'en' for English
train_csv: !ref <save_folder>/train.csv
valid_csv: !ref <save_folder>/dev.csv
test_csv: !ref <save_folder>/test.csv
skip_prep: False # Skip data preparation
# We remove utterances longer than 10s in the train/dev/test sets, as
# longer sentences almost certainly correspond to "open microphones".
avoid_if_longer_than: 10.0
####################### Training Parameters ####################################
number_of_epochs: 30
lr: 1.0
lr_wav2vec: 0.0001
sorting: ascending
precision: fp32 # bf16, fp16 or fp32
sample_rate: 16000
ckpt_interval_minutes: 30 # save checkpoint every N min
# With data_parallel batch_size is split into N jobs
# With DDP batch_size is multiplied by N jobs
# Must be 8 per GPU to fit in 32GB of VRAM
dynamic_batching: False
batch_size: 12
test_batch_size: 4
dataloader_options:
    batch_size: !ref <batch_size>
    num_workers: 6
test_dataloader_options:
    batch_size: !ref <test_batch_size>
    num_workers: 6
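# If dynamic_batching is enabled above, a bucketing sampler along these lines
# is the usual companion. This is a commented-out sketch only: the exact keys
# read by the training script may differ, and max_batch_length is assumed to
# be measured in seconds of audio.
# dynamic_batch_sampler:
#     max_batch_length: 60
#     num_buckets: 30
#     shuffle: True
#     batch_ordering: random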
# BPE parameters
token_type: unigram # ["unigram", "bpe", "char"]
character_coverage: 1.0
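# Note: the tokenizer trained with these settings must yield exactly as many
# tokens as <output_neurons> below (1000 here), since the CTC head sizes its
# output layer from that value.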
####################### Model Parameters #######################################
wav2vec_output_dim: 1024
dnn_neurons: 1024
freeze_wav2vec: False
freeze_feature_extractor: False
dropout: 0.15
warmup_steps: 500
# Outputs
output_neurons: 1000 # BPE size, index(blank/eos/bos) = 0
# Decoding parameters
# Be sure that the bos and eos index match with the BPEs ones
blank_index: 0
bos_index: 1
eos_index: 2
test_beam_search:
    blank_index: !ref <blank_index>
    beam_size: 100
    beam_prune_logp: -12.0
    token_prune_min_logp: -1.2
    prune_history: True
    topk: 1
    alpha: 1.0
    beta: 0.5
# To use n-gram LM for decoding, follow steps in README.md.
# kenlm_model_path: none
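# For example (hypothetical path, kept commented out; alpha and beta above are
# the LM weight and word-insertion bonus, which only take effect once such a
# model is supplied):
# kenlm_model_path: !ref <save_folder>/lm/5gram.arpa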
#
# Functions and classes
#
epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
    limit: !ref <number_of_epochs>
############################## Augmentations ###################################
# Speed perturbation
speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
    orig_freq: !ref <sample_rate>
    speeds: [95, 100, 105]
# Frequency drop: randomly drops a number of frequency bands to zero.
drop_freq: !new:speechbrain.augment.time_domain.DropFreq
    drop_freq_low: 0
    drop_freq_high: 1
    drop_freq_count_low: 1
    drop_freq_count_high: 3
    drop_freq_width: 0.05
# Time drop: randomly drops a number of temporal chunks.
drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
    drop_length_low: 1000
    drop_length_high: 2000
    drop_count_low: 1
    drop_count_high: 5
# Augmenter: Combines previously defined augmentations to perform data augmentation
wav_augment: !new:speechbrain.augment.augmenter.Augmenter
    min_augmentations: 3
    max_augmentations: 3
    augment_prob: 1.0
    augmentations: [
        !ref <speed_perturb>,
        !ref <drop_freq>,
        !ref <drop_chunk>]
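# With min_augmentations = max_augmentations = 3 and augment_prob = 1.0, all
# three augmentations above are applied to every training batch.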
############################## Models ##########################################
enc: !new:speechbrain.nnet.containers.Sequential
    input_shape: [null, null, !ref <wav2vec_output_dim>]
    linear1: !name:speechbrain.nnet.linear.Linear
        n_neurons: !ref <dnn_neurons>
        bias: True
    bn1: !name:speechbrain.nnet.normalization.BatchNorm1d
    activation: !new:torch.nn.LeakyReLU
    drop: !new:torch.nn.Dropout
        p: !ref <dropout>
    linear2: !name:speechbrain.nnet.linear.Linear
        n_neurons: !ref <dnn_neurons>
        bias: True
    bn2: !name:speechbrain.nnet.normalization.BatchNorm1d
    activation2: !new:torch.nn.LeakyReLU
    drop2: !new:torch.nn.Dropout
        p: !ref <dropout>
    linear3: !name:speechbrain.nnet.linear.Linear
        n_neurons: !ref <dnn_neurons>
        bias: True
    bn3: !name:speechbrain.nnet.normalization.BatchNorm1d
    activation3: !new:torch.nn.LeakyReLU
wav2vec2: !new:speechbrain.lobes.models.huggingface_transformers.wav2vec2.Wav2Vec2
    source: !ref <wav2vec2_hub>
    output_norm: True
    freeze: !ref <freeze_wav2vec>
    freeze_feature_extractor: !ref <freeze_feature_extractor>
    save_path: !ref <wav2vec2_folder>
#####
# Uncomment this block if you prefer to use a Fairseq pretrained model instead
# of a HuggingFace one. Here, we provide a URL obtained from the Fairseq
# GitHub repository for the multilingual XLSR model.
#
#wav2vec2_url: https://dl.fbaipublicfiles.com/fairseq/wav2vec/xlsr_53_56k.pt
#wav2vec2: !new:speechbrain.lobes.models.fairseq_wav2vec.FairseqWav2Vec2
#    pretrained_path: !ref <wav2vec2_url>
#    output_norm: True
#    freeze: False
#    save_path: !ref <save_folder>/wav2vec2_checkpoint/model.pt
#####
ctc_lin: !new:speechbrain.nnet.linear.Linear
    input_size: !ref <dnn_neurons>
    n_neurons: !ref <output_neurons>
log_softmax: !new:speechbrain.nnet.activations.Softmax
    apply_log: True
ctc_cost: !name:speechbrain.nnet.losses.ctc_loss
    blank_index: !ref <blank_index>
modules:
    wav2vec2: !ref <wav2vec2>
    enc: !ref <enc>
    ctc_lin: !ref <ctc_lin>
model: !new:torch.nn.ModuleList
    - [!ref <enc>, !ref <ctc_lin>]
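# Note that wav2vec2 is deliberately kept out of <model>: the pretrained
# encoder is fine-tuned with its own optimizer and learning rate
# (wav2vec_opt_class / lr_wav2vec below), separate from the randomly
# initialized DNN and CTC head trained with <model_opt_class>.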
model_opt_class: !name:torch.optim.Adadelta
    lr: !ref <lr>
    rho: 0.95
    eps: 1.e-8
wav2vec_opt_class: !name:torch.optim.Adam
    lr: !ref <lr_wav2vec>
lr_annealing_model: !new:speechbrain.nnet.schedulers.NewBobScheduler
    initial_value: !ref <lr>
    improvement_threshold: 0.0025
    annealing_factor: 0.8
    patient: 0
lr_annealing_wav2vec: !new:speechbrain.nnet.schedulers.NewBobScheduler
    initial_value: !ref <lr_wav2vec>
    improvement_threshold: 0.0025
    annealing_factor: 0.9
    patient: 0
checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
    checkpoints_dir: !ref <save_folder>
    recoverables:
        wav2vec2: !ref <wav2vec2>
        model: !ref <model>
        scheduler_model: !ref <lr_annealing_model>
        scheduler_wav2vec: !ref <lr_annealing_wav2vec>
        counter: !ref <epoch_counter>
train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
    save_file: !ref <train_log>
error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
    split_tokens: True
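# Typical invocation (script name assumed from the standard SpeechBrain
# CommonVoice CTC recipe layout; adjust to your checkout):
# python train_with_wav2vec.py hparams/train_es_with_wav2vec.yaml \
#     --data_folder=/path/to/cv-corpus/es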