# File: train_hf_wav2vec_full.yaml (from a fork of speechbrain/speechbrain)
# ################################
# Model: Wav2Vec + DNN + CTC + Softmax
# Authors:
# Gaëlle Laperrière 2023
# ################################
# ------ Paths and parameters
# Seed needs to be set at the top of the yaml, before objects with parameters
# are made.
seed: 4242
# Applies torch.manual_seed to <seed>.
__set_seed: !apply:torch.manual_seed [!ref <seed>]
# Every path below is derived from this seed-specific output folder.
output_folder: !ref results/media_SLU_wav2vec_full/<seed>
# Test-set text files, one per metric (CER, CTC, COER, CVER, going by the
# key names).
cer_file_test: !ref <output_folder>/cer_test.txt
ctc_file_test: !ref <output_folder>/ctc_test.txt
coer_file_test: !ref <output_folder>/coer_test.txt
cver_file_test: !ref <output_folder>/cver_test.txt
# Also used as the checkpoint directory and as the parent of the csv files.
save_folder: !ref <output_folder>/save
train_log: !ref <output_folder>/train_log.txt
# Path of the folders S0272 and E0024, used to process the original ELRA xml
# datasets.
data_folder: !PLACEHOLDER
# Path of the channels.csv file downloaded via
# https://www.dropbox.com/sh/y7ab0lktbylz647/AADMsowYHmNYwaoL_hQt7NMha?dl=0
channels_path: !PLACEHOLDER
# Path of the concepts_full_relax.csv file downloaded via
# https://www.dropbox.com/sh/y7ab0lktbylz647/AADMsowYHmNYwaoL_hQt7NMha?dl=0
concepts_path: !PLACEHOLDER
# Skip storing the wav files if this was already done before.
skip_wav: False
# Whether specifiers are removed from or kept in the concepts.
# "full" is the method used by default.
method: full
# Parse SLU or ASR data.
task: slu
# Skip the data preparation to csv if it was already done.
skip_prep: False
# Process the test2 corpus.
process_test2: False
# See https://github.com/pytorch/fairseq/blob/main/examples/wav2vec/README.md
# for Wav2Vec models and https://huggingface.co/LeBenchmark for French ones.
# This value is passed as `source` to the wav2vec2 object.
wav2vec_url: LeBenchmark/wav2vec2-FR-3K-large
# Data files (all located under <save_folder>/csv):
csv_train: !ref <save_folder>/csv/train.csv
csv_valid: !ref <save_folder>/csv/dev.csv
# If test2 was processed, this can be changed to test2.csv.
csv_test: !ref <save_folder>/csv/test.csv
# Data parameters:
# With data_parallel, batch_size is split into N jobs.
# With DDP, batch_size is multiplied by N jobs.
# batch_size is referenced by dataloader_options, test_batch_size by
# test_dataloader_options.
batch_size: 4
test_batch_size: 2
# We remove utterances longer than 90s in the train/dev/test sets, as
# longer sentences certainly correspond to "open microphones".
avoid_if_longer_than: 90.0
# Lower duration bound (presumably in seconds, like the upper bound; confirm
# against the training script).
avoid_if_smaller_than: 0.0
# Shared by both dataloader option sets.
num_workers: 3
# Dataloader options built from the top-level batch_size and num_workers.
# The entries must be nested under this key: at column 0 they would redefine
# the top-level batch_size / num_workers and reference themselves.
dataloader_options:
    batch_size: !ref <batch_size>
    num_workers: !ref <num_workers>
    shuffle: True
# Dataloader options built from the top-level test_batch_size and
# num_workers. The entries must be nested under this key so they do not
# redefine the top-level batch_size / num_workers.
test_dataloader_options:
    batch_size: !ref <test_batch_size>
    num_workers: !ref <num_workers>
# Feature parameters:
sample_rate: 16000
# Used as the last dimension of the input_shape given to `enc`.
feats_dim: 1024
####################### Training Parameters ####################################
# Used as the limit of epoch_counter.
number_of_epochs: 30
# Initial learning rates: `lr` feeds opt_class and lr_annealing,
# `lr_wav2vec` feeds opt_class_wav2vec and lr_annealing_wav2vec.
lr: 1
lr_wav2vec: 0.0001
# NewBobScheduler settings; the *_wav2vec variants configure
# lr_annealing_wav2vec, the others lr_annealing.
annealing_factor: 0.8
annealing_factor_wav2vec: 0.9
improvement_threshold: 0.0025
improvement_threshold_wav2vec: 0.0025
patient: 0
patient_wav2vec: 0
sorting: ascending
####################### Model Parameters #######################################
# Activation, number of blocks and layer width passed to the `enc` VanillaNN;
# dnn_neurons is also the input size of output_lin.
activation: !name:torch.nn.LeakyReLU
dnn_blocks: 3
dnn_neurons: 512
# Wav2Vec parameters (passed as `freeze` to the wav2vec2 object):
freeze: False
# Decoding parameters (blank index used by ctc_cost and ctc_computer):
blank_index: 0
# Outputs (number of neurons of output_lin):
output_neurons: 212
# ------ Functions and classes
# Epoch counter whose limit is <number_of_epochs>.
# `limit` is a constructor argument and must be nested under the !new tag.
epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
    limit: !ref <number_of_epochs>
# Wav2Vec2 object built from <wav2vec_url>, with output normalisation on and
# `freeze` following the top-level flag. The arguments must be nested under
# the !new tag; at column 0 `freeze` would redefine the top-level key and
# reference itself.
wav2vec2: !new:speechbrain.lobes.models.huggingface_transformers.wav2vec2.Wav2Vec2
    source: !ref <wav2vec_url>
    output_norm: True
    freeze: !ref <freeze>
    save_path: !ref <save_folder>/wav2vec2_checkpoint
# VanillaNN configured with <dnn_blocks> blocks of <dnn_neurons> neurons and
# the <activation> class, for inputs whose last dimension is <feats_dim>.
# The arguments must be nested under the !new tag so they do not redefine
# the top-level activation / dnn_blocks / dnn_neurons keys.
enc: !new:speechbrain.lobes.models.VanillaNN.VanillaNN
    input_shape: [null, null, !ref <feats_dim>]
    activation: !ref <activation>
    dnn_blocks: !ref <dnn_blocks>
    dnn_neurons: !ref <dnn_neurons>
# Linear layer from <dnn_neurons> inputs to <output_neurons> outputs, with
# bias. The arguments must be nested under the !new tag.
output_lin: !new:speechbrain.nnet.linear.Linear
    input_size: !ref <dnn_neurons>
    n_neurons: !ref <output_neurons>
    bias: True
# Softmax activation with apply_log enabled.
# `apply_log` is a constructor argument and must be nested under the !new tag.
log_softmax: !new:speechbrain.nnet.activations.Softmax
    apply_log: True
# CTC loss function with blank_index pre-set to <blank_index>.
# The keyword must be nested under the !name tag; at column 0 it would
# redefine the top-level blank_index and reference itself.
ctc_cost: !name:speechbrain.nnet.losses.ctc_loss
    blank_index: !ref <blank_index>
# Mapping that exposes the wav2vec2, enc and output_lin objects by name.
# The entries must be nested under this key so they do not redefine the
# top-level objects they reference.
modules:
    wav2vec2: !ref <wav2vec2>
    enc: !ref <enc>
    output_lin: !ref <output_lin>
# ModuleList grouping enc and output_lin; the checkpointer saves it as
# `model`. The list is the positional argument of ModuleList and must be
# nested under the !new tag.
model: !new:torch.nn.ModuleList
    - [!ref <enc>, !ref <output_lin>]
# ModuleList holding only the wav2vec2 object. The list is the positional
# argument of ModuleList and must be nested under the !new tag.
model_wav2vec2: !new:torch.nn.ModuleList
    - [!ref <wav2vec2>]
# Adadelta optimizer constructor with lr pre-set to <lr>.
# The keywords must be nested under the !name tag; at column 0 `lr` would
# redefine the top-level key and reference itself.
opt_class: !name:torch.optim.Adadelta
    lr: !ref <lr>
    rho: 0.95
    eps: 1.e-8
# Adam optimizer constructor with lr pre-set to <lr_wav2vec>.
# The keyword must be nested under the !name tag so it does not redefine
# the top-level lr.
opt_class_wav2vec: !name:torch.optim.Adam
    lr: !ref <lr_wav2vec>
# NewBob scheduler starting at <lr>, configured with the non-wav2vec
# threshold / factor / patient values. The arguments must be nested under
# the !new tag so they do not redefine the top-level keys.
lr_annealing: !new:speechbrain.nnet.schedulers.NewBobScheduler
    initial_value: !ref <lr>
    improvement_threshold: !ref <improvement_threshold>
    annealing_factor: !ref <annealing_factor>
    patient: !ref <patient>
# NewBob scheduler starting at <lr_wav2vec>, configured with the *_wav2vec
# threshold / factor / patient values. The arguments must be nested under
# the !new tag so they do not redefine the top-level keys.
lr_annealing_wav2vec: !new:speechbrain.nnet.schedulers.NewBobScheduler
    initial_value: !ref <lr_wav2vec>
    improvement_threshold: !ref <improvement_threshold_wav2vec>
    annealing_factor: !ref <annealing_factor_wav2vec>
    patient: !ref <patient_wav2vec>
# CTC text encoder, created without arguments; the checkpointer saves it
# under the `tokenizer` key.
label_encoder: !new:speechbrain.dataio.encoder.CTCTextEncoder
# Checkpointer writing to <save_folder>. `recoverables` maps checkpoint names
# to the objects to save: both module groups, both schedulers, the epoch
# counter and the label encoder. Two nesting levels are required; at column 0
# these entries would redefine the top-level objects they reference.
checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
    checkpoints_dir: !ref <save_folder>
    recoverables:
        model: !ref <model>
        wav2vec2: !ref <wav2vec2>
        lr_annealing: !ref <lr_annealing>
        lr_annealing_wav2vec: !ref <lr_annealing_wav2vec>
        counter: !ref <epoch_counter>
        tokenizer: !ref <label_encoder>
# File-based training logger writing to <train_log>.
# `save_file` is a constructor argument and must be nested under the !new tag.
train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
    save_file: !ref <train_log>
# MetricStats constructor whose metric is ctc_loss with blank_index set to
# <blank_index> and reduction set to `batch`. Two nesting levels are
# required: `metric` belongs to MetricStats, and blank_index / reduction
# belong to the ctc_loss partial.
ctc_computer: !name:speechbrain.utils.metric_stats.MetricStats
    metric: !name:speechbrain.nnet.losses.ctc_loss
        blank_index: !ref <blank_index>
        reduction: batch
# ErrorRateStats constructor with no preset arguments.
cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
# ErrorRateStats constructor that extracts concepts delimited by '<' and '>'
# and has keep_values disabled. The keywords must be nested under the !name
# tag; at column 0 they would collide with the cver_computer settings.
coer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
    extract_concepts_values: True
    keep_values: False
    tag_in: '<'
    tag_out: '>'
# ErrorRateStats constructor that extracts concepts delimited by '<' and '>'
# and has keep_values enabled. The keywords must be nested under the !name
# tag; at column 0 they would collide with the coer_computer settings.
cver_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
    extract_concepts_values: True
    keep_values: True
    tag_in: '<'
    tag_out: '>'