wav2vec 2.0 learns speech representations on unlabeled data as described in wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations (Baevski et al., 2020).
We also learned speech representations in multiple languages, as described in Unsupervised Cross-lingual Representation Learning for Speech Recognition (Conneau et al., 2020).
We also combined wav2vec 2.0 with self-training in Self-training and Pre-training are Complementary for Speech Recognition (Xu et al., 2020).
Model | Finetuning split | Dataset | Checkpoint |
---|---|---|---|
Wav2Vec 2.0 Base | No finetuning | Librispeech | download |
Wav2Vec 2.0 Base | 10 minutes | Librispeech | download |
Wav2Vec 2.0 Base | 100 hours | Librispeech | download |
Wav2Vec 2.0 Base | 960 hours | Librispeech | download |
Wav2Vec 2.0 Large | No finetuning | Librispeech | download |
Wav2Vec 2.0 Large | 10 minutes | Librispeech | download |
Wav2Vec 2.0 Large | 100 hours | Librispeech | download |
Wav2Vec 2.0 Large | 960 hours | Librispeech | download |
Wav2Vec 2.0 Large (LV-60)* | No finetuning | Libri-Light | download |
Wav2Vec 2.0 Large (LV-60)* | 10 minutes | Libri-Light + Librispeech | download |
Wav2Vec 2.0 Large (LV-60)* | 100 hours | Libri-Light + Librispeech | download |
Wav2Vec 2.0 Large (LV-60)* | 960 hours | Libri-Light + Librispeech | download |
Wav2Vec 2.0 Large (LV-60) + Self Training * | 10 minutes | Libri-Light + Librispeech | download |
Wav2Vec 2.0 Large (LV-60) + Self Training * | 100 hours | Libri-Light + Librispeech | download |
Wav2Vec 2.0 Large (LV-60) + Self Training * | 960 hours | Libri-Light + Librispeech | download |
* updated (Oct. 24, 2020)
We also release multilingual pre-trained wav2vec 2.0 (XLSR) models:
Model | Architecture | Hours | Languages | Datasets | Checkpoint |
---|---|---|---|---|---|
XLSR-53 | Large | 56k | 53 | MLS, CommonVoice, BABEL | download |
The XLSR model uses the following datasets for multilingual pretraining:
- MLS: Multilingual LibriSpeech (8 languages, 50.7k hours): Dutch, English, French, German, Italian, Polish, Portuguese, Spanish
- CommonVoice (36 languages, 3.6k hours): Arabic, Basque, Breton, Chinese (CN), Chinese (HK), Chinese (TW), Chuvash, Dhivehi, Dutch, English, Esperanto, Estonian, French, German, Hakha Chin, Indonesian, Interlingua, Irish, Italian, Japanese, Kabyle, Kinyarwanda, Kyrgyz, Latvian, Mongolian, Persian, Portuguese, Russian, Sakha, Slovenian, Spanish, Swedish, Tamil, Tatar, Turkish, Welsh (see also finetuning splits from this paper).
- Babel (17 languages, 1.7k hours): Assamese, Bengali, Cantonese, Cebuano, Georgian, Haitian, Kazakh, Kurmanji, Lao, Pashto, Swahili, Tagalog, Tamil, Tok Pisin, Turkish, Vietnamese, Zulu
To prepare the training data manifest, start with a directory containing the wav files to be used for pretraining (we recommend splitting each file into separate files 10 to 30 seconds in length).
First, install the soundfile library:
pip install soundfile
Next, run:
$ python examples/wav2vec/wav2vec_manifest.py /path/to/waves --dest /manifest/path --ext $ext --valid-percent $valid
$ext should be set to flac, wav, or whatever format your dataset happens to use that soundfile can read.
$valid should be set to some reasonable fraction (like 0.01) of the training data to use for validation. To use a pre-defined validation set (like dev-other from Librispeech), set it to 0 and then overwrite valid.tsv with a separately pre-processed manifest file.
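As a rough sanity check on the generated manifest, the sketch below counts the files it lists and sums their duration. It assumes the layout that wav2vec_manifest.py is understood to write: the first line holds the root directory, and every following line holds a relative file path and a frame count separated by a tab. The function name `manifest_hours` and the 16 kHz default are illustrative and not part of fairseq.

```python
from pathlib import Path


def manifest_hours(tsv_path, sample_rate=16000):
    """Return (number of files, total hours) listed in a manifest.

    Assumed layout: line 1 is the root directory; every other line is
    a relative path and a frame count separated by a tab character.
    """
    lines = Path(tsv_path).read_text().splitlines()
    # Skip the root-directory line and any blank lines.
    entries = [line.split("\t") for line in lines[1:] if line.strip()]
    total_frames = sum(int(frames) for _path, frames in entries)
    return len(entries), total_frames / sample_rate / 3600
```

Calling `manifest_hours("/manifest/path/train.tsv")` and then the same on valid.tsv gives a quick view of how the `--valid-percent` split turned out.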
This configuration was used for the base model trained on the Librispeech dataset in the wav2vec 2.0 paper.
Note that the input is expected to be single channel, sampled at 16 kHz:
$ fairseq-hydra-train \
task.data=/path/to/data \
--config-dir /path/to/fairseq-py/examples/wav2vec/config/pretraining \
--config-name wav2vec2_base_librispeech
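The single-channel, 16 kHz input requirement noted above can be checked before training. The following is a minimal sketch using only Python's standard `wave` module, so it handles uncompressed PCM WAV files only (for flac or other formats, a library such as soundfile would be needed). The function name `is_mono_16khz` is hypothetical.

```python
import wave


def is_mono_16khz(wav_path):
    """Return True if a PCM WAV file is single channel and sampled at 16 kHz."""
    with wave.open(str(wav_path), "rb") as wav_file:
        return wav_file.getnchannels() == 1 and wav_file.getframerate() == 16000
```

Files that fail the check would need to be downmixed or resampled before they are added to the manifest.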
Note: you can simulate 64 GPUs by using k GPUs and adding the command line parameters distributed_training.distributed_world_size=k and +optimization.update_freq='[x]' (before --config-dir), where x = 64/k.
This configuration was used for the large model trained on the Libri-Light dataset in the wav2vec 2.0 paper:
$ fairseq-hydra-train \
task.data=/path/to/data \
--config-dir /path/to/fairseq-py/examples/wav2vec/config/pretraining \
--config-name wav2vec2_large_librivox
Note: you can simulate 128 GPUs by using k GPUs and adding the command line parameters distributed_training.distributed_world_size=k and +optimization.update_freq='[x]' (before --config-dir), where x = 128/k.
Fine-tuning a model requires parallel audio and label files, as well as a vocabulary file in fairseq format. A letter vocabulary can be downloaded here. An example script that generates labels for the Librispeech dataset from the tsv file produced by wav2vec_manifest.py can be used as follows:
$ split=train
$ python libri_labels.py /path/to/tsv --output-dir /output/dir --output-name $split
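To illustrate what letter targets look like, here is a minimal sketch of the conversion from a word transcript to a letter sequence. It assumes the convention that libri_labels.py is understood to follow: letters separated by spaces, with "|" marking each word boundary, including after the final word. The function name `words_to_letters` is hypothetical.

```python
def words_to_letters(transcript):
    """Turn a word transcript into space-separated letters with "|" as the word boundary."""
    words = transcript.strip().split()
    # Join words with "|", then put a space between every character.
    return " ".join(list("|".join(words))) + " |"


print(words_to_letters("HELLO WORLD"))  # H E L L O | W O R L D |
```

Every symbol produced this way, including "|", has to appear in the letter vocabulary file used for fine-tuning.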
Fine-tuning on 100h of Librispeech with letter targets:
$ fairseq-hydra-train \
distributed_training.distributed_port=$PORT \
task.data=/path/to/data \
model.w2v_path=/path/to/model.pt \
--config-dir /path/to/fairseq-py/examples/wav2vec/config/finetuning \
--config-name base_100h
There are other config files in the config/finetuning directory that can be used to fine-tune on other splits.
You can specify the right config via the --config-name parameter.
Note: you can simulate 24 GPUs by using k GPUs and adding the command line parameters distributed_training.distributed_world_size=k and +optimization.update_freq='[x]' (before --config-dir), where x = 24/k.
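The notes on simulating 64, 128, and 24 GPUs all use the same arithmetic: the update frequency is the target GPU count divided by the number of GPUs actually available. A small sketch of that calculation follows. The helper name `simulated_gpu_overrides` is hypothetical, and it assumes k divides the target count evenly, which the notes imply but do not state.

```python
def simulated_gpu_overrides(target_gpus, available_gpus):
    """Build the two overrides that emulate target_gpus on available_gpus."""
    if available_gpus <= 0 or target_gpus % available_gpus != 0:
        raise ValueError("available_gpus must be positive and divide target_gpus evenly")
    update_freq = target_gpus // available_gpus  # x = target / k
    return [
        f"distributed_training.distributed_world_size={available_gpus}",
        f"+optimization.update_freq='[{update_freq}]'",
    ]


# Text to paste into the shell command before --config-dir.
print(" ".join(simulated_gpu_overrides(64, 8)))
```

For example, fine-tuning on 8 GPUs instead of 24 gives an update frequency of 3.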
Decoding with a language model during training requires wav2letter python bindings.
If you want to use a language model, add +criterion.wer_args='[/path/to/kenlm, /path/to/lexicon, 2, -1]' to the command line.
Evaluating a CTC model with a language model requires wav2letter python bindings to be installed.
The Fairseq transformer language model used in the wav2vec 2.0 paper can be obtained from the wav2letter model repository. Be sure to upper-case the language model vocab after downloading it.
The letter dictionary for pre-trained models can be found here.
Next, run the evaluation command:
$ subset=dev_other
$ python examples/speech_recognition/infer.py /checkpoint/abaevski/data/speech/libri/10h/wav2vec/raw --task audio_pretraining \
--nbest 1 --path /path/to/model --gen-subset $subset --results-path /path/to/save/results/for/sclite --w2l-decoder kenlm \
--lm-model /path/to/kenlm.bin --lm-weight 2 --word-score -1 --sil-weight 0 --criterion ctc --labels ltr --max-tokens 4000000 \
--post-process letter
To get raw numbers, use --w2l-decoder viterbi and omit the lexicon. To use the transformer language model, use --w2l-decoder fairseqlm.
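The results written for sclite are scored as word error rates. For readers who want to spot-check a single hypothesis without sclite, the following is a minimal sketch of the metric, computed as word-level edit distance divided by the reference length. It is an illustration only, not the scoring code used by fairseq or sclite, and the function name `word_error_rate` is hypothetical.

```python
def word_error_rate(reference, hypothesis):
    """Word-level edit distance between the two strings, divided by reference length."""
    ref = reference.split()
    hyp = hypothesis.split()
    if not ref:
        raise ValueError("reference must contain at least one word")
    # previous[j] holds the distance between the reference prefix seen so far and hyp[:j].
    previous = list(range(len(hyp) + 1))
    for i, ref_word in enumerate(ref, start=1):
        current = [i]
        for j, hyp_word in enumerate(hyp, start=1):
            substitution = previous[j - 1] + (ref_word != hyp_word)
            deletion = previous[j] + 1
            insertion = current[j - 1] + 1
            current.append(min(substitution, deletion, insertion))
        previous = current
    return previous[-1] / len(ref)
```

Because insertions count as errors, the value can exceed 1.0 when the hypothesis is much longer than the reference.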
Example to train a wav2vec model as described in wav2vec: Unsupervised Pre-training for Speech Recognition (Schneider et al., 2019).
Description | Dataset | Model |
---|---|---|
Wav2Vec large | Librispeech | download |
import torch
import fairseq
cp_path = '/path/to/wav2vec.pt'
# load_model_ensemble_and_task expects a list of checkpoint paths, not a loaded checkpoint
model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([cp_path])
model = model[0]
model.eval()
wav_input_16khz = torch.randn(1,10000)
z = model.feature_extractor(wav_input_16khz)
c = model.feature_aggregator(z)
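To train a new model, start with a directory containing the wav files to be used for pretraining (we recommend splitting each file into separate files 10 to 30 seconds in length), prepare the manifest, train, and then extract embeddings from the downstream task data: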
$ python examples/wav2vec/wav2vec_manifest.py /path/to/waves --dest /manifest/path --ext wav
$ python train.py /manifest/path --save-dir /model/path --num-workers 6 --fp16 --max-update 400000 --save-interval 1 --no-epoch-checkpoints \
--arch wav2vec --task audio_pretraining --min-lr 1e-06 --stop-min-lr 1e-09 --optimizer adam --lr 0.005 --lr-scheduler cosine \
--conv-feature-layers '[(512, 10, 5), (512, 8, 4), (512, 4, 2), (512, 4, 2), (512, 4, 2), (512, 1, 1), (512, 1, 1)]' \
--conv-aggregator-layers '[(512, 2, 1), (512, 3, 1), (512, 4, 1), (512, 5, 1), (512, 6, 1), (512, 7, 1), (512, 8, 1), (512, 9, 1), (512, 10, 1), (512, 11, 1), (512, 12, 1), (512, 13, 1)]' \
--skip-connections-agg --residual-scale 0.5 --log-compression --warmup-updates 500 --warmup-init-lr 1e-07 --criterion wav2vec --num-negatives 10 \
--max-sample-size 150000 --max-tokens 1500000 --skip-invalid-size-inputs-valid-test
$ PYTHONPATH=/path/to/fairseq python examples/wav2vec/wav2vec_featurize.py --input /path/to/task/waves --output /path/to/output \
--model /model/path/checkpoint_best.pt --split train valid test
Example to train a vq-wav2vec model as described in vq-wav2vec: Self-Supervised Learning of Discrete Speech Representations (Baevski et al., 2019).
These models are also used in Effectiveness of self-supervised pre-training for speech recognition (Baevski et al., 2019).
Description | Dataset | Model |
---|---|---|
vq-wav2vec Gumbel | Librispeech | download |
vq-wav2vec K-means | Librispeech | download |
Roberta on K-means codes | Librispeech | download |
import torch
import fairseq
cp_path = '/path/to/vq-wav2vec.pt'
# load_model_ensemble_and_task expects a list of checkpoint paths, not a loaded checkpoint
model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([cp_path])
model = model[0]
model.eval()
wav_input_16khz = torch.randn(1,10000)
z = model.feature_extractor(wav_input_16khz)
_, idxs = model.vector_quantizer.forward_idx(z)
print(idxs.shape) # output: torch.Size([1, 60, 2]), 60 timesteps with 2 indexes corresponding to 2 groups in the model
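The 60 timesteps reported above follow from the strides of the convolutional feature extractor. The sketch below reproduces that number from the --conv-feature-layers value in the vq-wav2vec training command in this document, assuming each layer is an unpadded 1-D convolution described as (channels, kernel size, stride). The function name `conv_output_length` is hypothetical.

```python
def conv_output_length(num_samples, conv_layers):
    """Number of output frames after a stack of unpadded 1-D convolutions."""
    length = num_samples
    for _channels, kernel, stride in conv_layers:
        length = (length - kernel) // stride + 1
    return length


# Layer spec copied from the vq-wav2vec training command.
VQ_FEATURE_LAYERS = [(512, 10, 5), (512, 8, 4), (512, 4, 2), (512, 4, 2),
                     (512, 4, 2), (512, 1, 1), (512, 1, 1), (512, 1, 1)]

print(conv_output_length(10000, VQ_FEATURE_LAYERS))  # 60
```

The strides multiply to 5 * 4 * 2 * 2 * 2 = 160 samples, so under this assumption each timestep advances by 10 ms of 16 kHz audio.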
To train a new model, start with a directory containing the wav files to be used for pretraining (we recommend splitting each file into separate files 10 to 30 seconds in length), prepare the manifest, train, and then tokenize the audio data:
$ python examples/wav2vec/wav2vec_manifest.py /path/to/waves --dest /manifest/path --ext wav
$ python train.py /manifest/path --save-dir /model/path --num-workers 6 --fp16 --max-update 400000 \
--save-interval 1 --no-epoch-checkpoints --arch wav2vec --task audio_pretraining --min-lr 1e-06 --stop-min-lr 1e-09 \
--optimizer adam --lr 1e-05 --lr-scheduler cosine \
--conv-feature-layers '[(512, 10, 5), (512, 8, 4), (512, 4, 2), (512, 4, 2), (512, 4, 2), (512, 1, 1), (512, 1, 1), (512, 1, 1)]' \
--conv-aggregator-layers '[(512, 2, 1), (512, 3, 1), (512, 4, 1), (512, 5, 1), (512, 6, 1), (512, 7, 1), (512, 8, 1), (512, 9, 1), (512, 10, 1), (512, 11, 1), (512, 12, 1), (512, 13, 1)]' \
--activation gelu --offset auto --skip-connections-agg --residual-scale 0.5 \
--log-keys '["prob_perplexity","code_perplexity","temp"]' --vq-type gumbel --vq-groups 2 --vq-depth 2 \
--combine-groups --vq-vars 320 --vq-temp '(2,0.5,0.999995)' --prediction-steps 12 --warmup-updates 1000 \
--warmup-init-lr 1e-07 --criterion wav2vec --num-negatives 10 --max-sample-size 150000 \
--max-tokens 300000 --cross-sample-negatives 0 --update-freq 1 --seed 2 --skip-invalid-size-inputs-valid-test
For k-means training, set --vq-type to kmeans and add the --loss-weights [1] argument. The pre-trained models were trained on 16 GPUs.
$ PYTHONPATH=/path/to/fairseq python examples/wav2vec/vq-wav2vec_featurize.py --data-dir /manifest/path --output-dir /path/to/output \
--checkpoint /model/path/checkpoint_best.pt --split train valid test --extension tsv