forked from facebookresearch/fairseq
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
wav2vec model (facebookresearch#654)
Summary: Merging wav2vec to master. Includes renames (Cpc -> wav2vec) and some light example files. Pull Request resolved: fairinternal/fairseq-py#654 Differential Revision: D15913409 Pulled By: alexeib fbshipit-source-id: f723e6f211706cd9431c7d76dc12c4e80c9cfc80
- Loading branch information
1 parent
c55eda5
commit 94c24fe
Showing
12 changed files
with
1,091 additions
and
22 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
# wav2vec | ||
|
||
Example to train a wav2vec model as described in [wav2vec: Unsupervised Pre-training for Speech Recognition (Schneider et al., 2019)](https://arxiv.org/abs/1904.05862). | ||
|
||
## Training a new model with the CLI tools | ||
|
||
Given a directory containing wav files to be used for pretraining (we recommend splitting each file into separate file 10 to 30 seconds in length) | ||
|
||
### Prepare training data manifest: | ||
|
||
``` | ||
$ python scripts/wav2vec_manifest.py /path/to/waves --dest /manifest/path --ext wav | ||
``` | ||
|
||
### Train a wav2vec model: | ||
|
||
``` | ||
$ python train.py /manifest/path --save-dir /model/path --num-workers 6 --fp16 --max-update 400000 --save-interval 1 --no-epoch-checkpoints \ | ||
--arch wav2vec --task audio_pretraining --lr 1e-06 --min-lr 1e-09 --optimizer adam --max-lr 0.005 --lr-scheduler cosine \ | ||
--conv-feature-layers [(512, 10, 5), (512, 8, 4), (512, 4, 2), (512, 4, 2), (512, 4, 2), (512, 1, 1), (512, 1, 1)] \ | ||
--conv-aggregator-layers [(512, 2, 1), (512, 3, 1), (512, 4, 1), (512, 5, 1), (512, 6, 1), (512, 7, 1), (512, 8, 1), (512, 9, 1), (512, 10, 1), (512, 11, 1), (512, 12, 1), (512, 13, 1)] \ | ||
--skip-connections-agg --residual-scale 0.5 --log-compression --warmup-updates 500 --warmup-init-lr 1e-07 --criterion binary_cross_entropy --num-negatives 10 \ | ||
--max-sample-size 150000 --max-tokens 1500000 ---skip-invalid-size-inputs-valid-test | ||
``` | ||
|
||
### Extract embeddings from the downstream task data: | ||
|
||
``` | ||
$ PYTHONPATH /path/to/fairseq python scripts/wav2vec_featurize.py --input /path/to/task/waves --output /path/to/output \ | ||
--model /model/path/checkpoint_best.pt --split train valid test | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
# Copyright (c) 2017-present, Facebook, Inc. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the license found in the LICENSE file in | ||
# the root directory of this source tree. An additional grant of patent rights | ||
# can be found in the PATENTS file in the same directory. | ||
|
||
import math | ||
import torch | ||
import torch.nn.functional as F | ||
|
||
from fairseq import utils | ||
|
||
from . import FairseqCriterion, register_criterion | ||
|
||
|
||
@register_criterion('binary_cross_entropy') | ||
class BinaryCrossEntropyCriterion(FairseqCriterion): | ||
|
||
def __init__(self, args, task): | ||
super().__init__(args, task) | ||
|
||
def forward(self, model, sample, reduce=True): | ||
"""Compute the loss for the given sample. | ||
Returns a tuple with three elements: | ||
1) the loss | ||
2) the sample size, which is used as the denominator for the gradient | ||
3) logging outputs to display while training | ||
""" | ||
net_output = model(**sample['net_input']) | ||
logits = model.get_logits(net_output).float() | ||
target = model.get_targets(sample, net_output, expand_steps=False).float() | ||
|
||
if hasattr(model, 'get_target_weights'): | ||
weights = model.get_target_weights(target, net_output) | ||
if torch.is_tensor(weights): | ||
weights = weights.float() | ||
else: | ||
weights = 1. | ||
|
||
loss = F.binary_cross_entropy_with_logits(logits, target, reduce=False) | ||
|
||
loss = loss * weights | ||
|
||
if reduce: | ||
loss = loss.sum() | ||
|
||
sample_size = target.numel() | ||
logging_output = { | ||
'loss': utils.item(loss.data) if reduce else loss.data, | ||
'ntokens': sample_size, | ||
'nsentences': logits.size(0), | ||
'sample_size': sample_size, | ||
} | ||
return loss, sample_size, logging_output | ||
|
||
@staticmethod | ||
def aggregate_logging_outputs(logging_outputs): | ||
"""Aggregate logging outputs from data parallel training.""" | ||
loss_sum = sum(log.get('loss', 0) for log in logging_outputs) | ||
ntokens = sum(log.get('ntokens', 0) for log in logging_outputs) | ||
nsentences = sum(log.get('nsentences', 0) for log in logging_outputs) | ||
sample_size = sum(log.get('sample_size', 0) for log in logging_outputs) | ||
agg_output = { | ||
'loss': loss_sum / sample_size / math.log(2), | ||
'ntokens': ntokens, | ||
'nsentences': nsentences, | ||
'sample_size': sample_size, | ||
} | ||
if sample_size != ntokens: | ||
agg_output['nll_loss'] = loss_sum / ntokens / math.log(2) | ||
return agg_output |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
# Copyright (c) 2017-present, Facebook, Inc. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the license found in the LICENSE file in | ||
# the root directory of this source tree. An additional grant of patent rights | ||
# can be found in the PATENTS file in the same directory. | ||
|
||
|
||
import os | ||
import numpy as np | ||
import sys | ||
import torch | ||
import torch.nn.functional as F | ||
|
||
from .. import FairseqDataset | ||
|
||
|
||
class RawAudioDataset(FairseqDataset): | ||
|
||
def __init__(self, manifest_path, sample_rate, max_sample_size=None, min_sample_size=None, | ||
shuffle=True): | ||
super().__init__() | ||
|
||
self.sample_rate = sample_rate | ||
self.fnames = [] | ||
self.sizes = [] | ||
self.max_sample_size = max_sample_size if max_sample_size is not None else sys.maxsize | ||
self.min_sample_size = min_sample_size if min_sample_size is not None else self.max_sample_size | ||
|
||
with open(manifest_path, 'r') as f: | ||
self.root_dir = f.readline().strip() | ||
for line in f: | ||
items = line.strip().split('\t') | ||
assert len(items) == 2, line | ||
self.fnames.append(items[0]) | ||
self.sizes.append(int(items[1])) | ||
self.shuffle = shuffle | ||
|
||
def __getitem__(self, index): | ||
fname = os.path.join(self.root_dir, self.fnames[index]) | ||
import soundfile as sf | ||
|
||
wav, curr_sample_rate = sf.read(fname) | ||
feats = torch.from_numpy(wav).float() | ||
|
||
if feats.dim() == 2: | ||
feats = feats.mean(-1) | ||
|
||
if curr_sample_rate != self.sample_rate: | ||
factor = self.sample_rate / curr_sample_rate | ||
feats = self.resample(feats, factor) | ||
|
||
assert feats.dim() == 1, feats.dim() | ||
|
||
return { | ||
'id': index, | ||
'source': feats, | ||
} | ||
|
||
def resample(self, x, factor): | ||
return F.interpolate(x.view(1, 1, -1), scale_factor=factor).squeeze() | ||
|
||
def __len__(self): | ||
return len(self.fnames) | ||
|
||
def collater(self, samples): | ||
if len(samples) == 0: | ||
return {} | ||
|
||
sources = [s['source'] for s in samples] | ||
sizes = [len(s) for s in sources] | ||
target_size = min(min(sizes), self.max_sample_size) | ||
|
||
if self.min_sample_size < target_size: | ||
target_size = np.random.randint(self.min_sample_size, target_size + 1) | ||
|
||
collated_sources = sources[0].new(len(sources), target_size) | ||
for i, (source, size) in enumerate(zip(sources, sizes)): | ||
diff = size - target_size | ||
assert diff >= 0 | ||
if diff == 0: | ||
collated_sources[i] = source | ||
else: | ||
start = np.random.randint(0, diff + 1) | ||
end = size - diff + start | ||
collated_sources[i] = source[start:end] | ||
|
||
return { | ||
'id': torch.LongTensor([s['id'] for s in samples]), | ||
'net_input': { | ||
'source': collated_sources, | ||
}, | ||
} | ||
|
||
def get_dummy_batch( | ||
self, num_tokens, max_positions, src_len=2048, tgt_len=128, | ||
): | ||
"""Return a dummy batch with a given number of tokens.""" | ||
if isinstance(max_positions, float) or isinstance(max_positions, int): | ||
src_len = min(src_len, max_positions) | ||
bsz = num_tokens // src_len | ||
return self.collater([ | ||
{ | ||
'id': i, | ||
'source': torch.rand(src_len), | ||
} | ||
for i in range(bsz) | ||
]) | ||
|
||
def num_tokens(self, index): | ||
return self.size(index) | ||
|
||
def size(self, index): | ||
"""Return an example's size as a float or tuple. This value is used when | ||
filtering a dataset with ``--max-positions``.""" | ||
return min(self.sizes[index], self.max_sample_size) | ||
|
||
def ordered_indices(self): | ||
"""Return an ordered list of indices. Batches will be constructed based | ||
on this order.""" | ||
|
||
if self.shuffle: | ||
order = [np.random.permutation(len(self))] | ||
else: | ||
order = [np.arange(len(self))] | ||
|
||
order.append(self.sizes) | ||
return np.lexsort(order) |
Oops, something went wrong.