Skip to content

Commit

Permalink
[egs] LibriCSS recipe (#4321)
Browse files Browse the repository at this point in the history
Refer to the README.md in each eg directory for a description.
  • Loading branch information
desh2608 authored Nov 13, 2020
1 parent 882b0a6 commit 1670662
Show file tree
Hide file tree
Showing 75 changed files with 5,787 additions and 97 deletions.
89 changes: 12 additions & 77 deletions egs/callhome_diarization/v1/diarization/vb_hmm_xvector.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#!/usr/bin/env python
#!/usr/bin/env python3
# Copyright 2020 Johns Hopkins University (Author: Desh Raj)
# Apache 2.0

Expand All @@ -9,7 +9,7 @@
# vb_hmm_xvector.sh which can divide all labels into per recording
# labels.

import sys, argparse, struct, re
import sys, argparse, struct
import numpy as np
import itertools
import kaldi_io
Expand All @@ -36,9 +36,6 @@ def get_args():
help="scale sufficient statistics collected using UBM")
parser.add_argument("--fb", type=float, default=11,
help="speaker regularization coefficient Fb (controls final # of speaker)")
parser.add_argument("--overlap-rttm", type=str,
help="path to an RTTM file containing overlap segments. If provided,"
"multiple speaker labels will be allocated to these segments.")
parser.add_argument("xvector_ark_file", type=str,
help="Ark file containing xvectors for all subsegments")
parser.add_argument("plda", type=str,
Expand All @@ -61,77 +58,25 @@ def read_labels_file(label_file):
return segments, labels

def write_labels_file(seg2label, out_file):
f = open(out_file, 'w')
for seg in sorted(seg2label.keys()):
label = seg2label[seg]
if type(label) is tuple:
f.write("{} {}\n".format(seg, label[0]))
f.write("{} {}\n".format(seg, label[1]))
else:
f.write("{} {}\n".format(seg, label))
f.close()
with open(out_file, 'w') as f:
for seg in sorted(seg2label.keys()):
label = seg2label[seg]
f.write(f"{seg} {label}\n")
return

def get_overlap_decision(overlap_segs, subsegment, frac = 0.5):
    """Decide whether a subsegment is covered enough by overlap regions.

    Args:
        overlap_segs: iterable of (start, end) pairs describing overlap
            regions. They may be given in any order; the intersection of
            each region with the subsegment is summed independently.
        subsegment: (start, end) pair for the subsegment being tested.
        frac: minimum fraction of the subsegment's duration that must be
            covered by the overlap regions.

    Returns:
        True if the summed intersection between the subsegment and the
        overlap regions is at least 'frac' times the subsegment duration.
    """
    start_time, end_time = subsegment[0], subsegment[1]
    dur = end_time - start_time

    # A region that does not intersect the subsegment contributes 0 (the
    # clamped difference is negative), so no early exit is needed and the
    # result no longer depends on overlap_segs being sorted by start time.
    total_ovl = sum(
        max(0, min(end_time, cur_end) - max(start_time, cur_start))
        for cur_start, cur_end in overlap_segs
    )

    return total_ovl >= frac * dur


def get_overlap_vector(overlap_rttm, segments):
    """Build a 0/1 vector marking which subsegments are overlapped speech.

    Args:
        overlap_rttm: path to an RTTM file whose lines describe overlap
            regions; field 2 is matched against the recording id, and
            fields 4 and 5 are read as onset and duration.
        segments: list of subsegment names. The recording id is taken
            from the first three '_'-separated fields of the first name.
            NOTE(review): the timing fields parsed from each name are
            divided by 100, so they are presumably in centiseconds
            (or 10 ms frames) -- confirm against the segment naming.

    Returns:
        A numpy array with one entry per segment: 1 where
        get_overlap_decision() reports the subsegment as overlapped,
        0 otherwise. An empty array is returned for an empty segment list.
    """
    if not segments:
        # Nothing to label; also avoids segments[0] and a division by zero.
        return np.zeros(0)

    reco_id = '_'.join(segments[0].split('_')[:3])

    overlap_segs = []
    with open(overlap_rttm, 'r') as f:
        for line in f:
            parts = line.strip().split()
            # Skip blank/short lines and lines for other recordings.
            if len(parts) < 5 or parts[1] != reco_id:
                continue
            onset = float(parts[3])
            overlap_segs.append((onset, onset + float(parts[4])))
    overlap_segs.sort(key=lambda x: x[0])

    ol_vec = np.zeros(len(segments))
    for i, segment in enumerate(segments):
        parts = re.split('_|-', segment)
        start_time = (float(parts[3]) + float(parts[5])) / 100
        end_time = (float(parts[3]) + float(parts[6])) / 100

        if get_overlap_decision(overlap_segs, (start_time, end_time)):
            ol_vec[i] = 1

    # Report the recording id (the original formatted the builtin `id`).
    print("{}: {} fraction of segments are overlapping".format(reco_id, ol_vec.sum()/len(ol_vec)))
    return ol_vec

def read_args(args):
segments, labels = read_labels_file(args.input_label_file)
xvec_all = dict(kaldi_io.read_vec_flt_ark(args.xvector_ark_file))
xvectors = []
for segment in segments:
xvectors.append(xvec_all[segment])
_, _, plda_psi = kaldi_io.read_plda(args.plda)
if (args.overlap_rttm is not None):
print('Getting overlap segments...')
overlaps = get_overlap_vector(args.overlap_rttm, segments)
else:
overlaps = None
return xvectors, segments, labels, plda_psi, overlaps
return xvectors, segments, labels, plda_psi


###################################################################

def vb_hmm(segments, in_labels, xvectors, overlaps, plda_psi, init_smoothing, loop_prob, fa, fb):
def vb_hmm(segments, in_labels, xvectors, plda_psi, init_smoothing, loop_prob, fa, fb):
x = np.array(xvectors)
dim = x.shape[1]

Expand All @@ -153,25 +98,15 @@ def vb_hmm(segments, in_labels, xvectors, overlaps, plda_psi, init_smoothing, lo
gamma=q_init, maxSpeakers=q_init.shape[1], maxIters=40, epsilon=1e-6, loopProb=loop_prob,
Fa=fa, Fb=fb)

labels = np.argsort(q, axis=1)[:,[-1,-2]]
labels = np.unique(q.argmax(1), return_inverse=True)[1]

if overlaps is not None:
final_labels = []
for i in range(len(overlaps)):
if (overlaps[i] == 1):
final_labels.append((labels[i,0], labels[i,1]))
else:
final_labels.append(labels[i,0])
else:
final_labels = labels[:,0]

return {seg:label for seg,label in zip(segments,final_labels)}
return {seg:label for seg,label in zip(segments,labels)}

def main():
args = get_args()
xvectors, segments, labels, plda_psi, overlaps = read_args(args)
xvectors, segments, labels, plda_psi = read_args(args)

seg2label_vb = vb_hmm(segments, labels, xvectors, overlaps, plda_psi, args.init_smoothing,
seg2label_vb = vb_hmm(segments, labels, xvectors, plda_psi, args.init_smoothing,
args.loop_prob, args.fa, args.fb)
write_labels_file(seg2label_vb, args.output_label_file)

Expand Down
23 changes: 3 additions & 20 deletions egs/callhome_diarization/v1/diarization/vb_hmm_xvector.sh
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ stage=0
nj=10
cleanup=true
rttm_channel=0
overlap_rttm= # Path to an RTTM output of an external overlap detector

# The hyperparameters used here are taken from the DIHARD
# optimal hyperparameter values reported in:
Expand Down Expand Up @@ -69,14 +68,6 @@ if [ "$result" == "0" ]; then
python3 -m pip install numexpr
fi

overlap_rttm_opt=
if ! [ -z "$overlap_rttm" ]; then
overlap_rttm_opt="--overlap-rttm $overlap_rttm"
rttm_bin="make_rttm_ol.py"
else
rttm_bin="make_rttm.py"
fi

if [ $stage -le 0 ]; then
# Mean subtraction (If original x-vectors are high-dim, e.g. 512, you should
# consider also applying LDA to reduce dimensionality to, say, 200)
Expand All @@ -85,18 +76,10 @@ if [ $stage -le 0 ]; then
fi

echo -e "Performing bayesian HMM based x-vector clustering..\n"
# making a shell script for each job
for n in `seq $nj`; do
cat <<-EOF > $dir/tmp/vb_hmm.$n.sh
python3 diarization/vb_hmm_xvector.py $overlap_rttm_opt \
--loop-prob $loop_prob --fa $fa --fb $fb \
$xvec_dir/xvector_norm.ark $plda $dir/labels.$n $dir/labels.vb.$n
EOF
done

chmod a+x $dir/tmp/vb_hmm.*.sh
$cmd JOB=1:$nj $dir/log/vb_hmm.JOB.log \
$dir/tmp/vb_hmm.JOB.sh
diarization/vb_hmm_xvector.py \
--loop-prob $loop_prob --fa $fa --fb $fb \
$xvec_dir/xvector_norm.ark $plda $dir/labels.JOB $dir/labels.vb.JOB

if [ $stage -le 1 ]; then
echo "$0: combining labels"
Expand Down
63 changes: 63 additions & 0 deletions egs/libri_css/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
### LibriCSS integrated recipe

This is a Kaldi recipe for the LibriCSS data, providing diarization and
ASR on mixed single-channel and separated audio inputs.

#### Data
We use the LibriCSS data released with the following paper:
```
@article{Chen2020ContinuousSS,
title={Continuous Speech Separation: Dataset and Analysis},
author={Z. Chen and T. Yoshioka and Liang Lu and T. Zhou and Zhong Meng and Yi Luo and J. Wu and J. Li},
journal={ICASSP 2020 - 2020 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
year={2020}
}
```
For the official data and code, check out [the official repo](https://github.com/chenzhuo1011/libri_css).

#### Recipe details
This recipe addresses the problem of speech recognition in a meeting-like
scenario, where multiple overlapping speakers may be present, and the
number of speakers is not known beforehand.

We provide recipes for 2 scenarios:
1. `s5_mono`: This is a single-channel diarization + ASR recipe which takes as
input a long single-channel recording containing mixed audio. It then performs SAD,
diarization, and ASR on it and outputs speaker-attributed transcriptions,
which are then evaluated with cpWER (similar to CHiME6 Track 2).
2. `s5_css`: This pipeline uses a speech separation module at the beginning,
so the input is 2-3 separated audio streams. We assume that the separation is
window-based, so that the same speaker may be split across different streams in
different windows, thus making diarization necessary.

#### Pretrained models for diarization and ASR
For ease of reproduction, we include the training for both modules in the
recipe. We also provide pretrained models for both diarization and ASR
systems.

* SAD: CHiME-6 baseline TDNN-Stats SAD available [here](http://kaldi-asr.org/models/m12).
* Speaker diarization: CHiME-6 baseline x-vector + AHC diarizer, trained on VoxCeleb
with simulated RIRs available [here](http://kaldi-asr.org/models/m12).
* ASR: We used the chain model trained on 960h clean LibriSpeech training data available
[here](http://kaldi-asr.org/models/m13). It was then additionally fine-tuned for 1
epoch on LibriSpeech + simulated RIRs. For LM, we trained a TDNN-LSTM language model
for rescoring. All of these models are available at this
[Google Drive link](https://drive.google.com/file/d/13ceXdK6oAUuUyxn7kjQVVqpe8r6Sc7ds/view?usp=sharing).

#### Speech separation
The speech separation module has not been provided. If you want to use the
`s5_css` recipe, check out [this tutorial](https://desh2608.github.io/pages/jsalt/) for
instructions on how to plug your own separation component into the pipeline.

If you found this recipe useful for your experiments, consider citing:

```
@article{Raj2021Integration,
title={Integration of speech separation, diarization, and recognition for multi-speaker meetings:
System description, Comparison, and Analysis},
author={D.Raj and P.Denisov and Z.Chen and H.Erdogan and Z.Huang and M.He and S.Watanabe and
J.Du and T.Yoshioka and Y.Luo and N.Kanda and J.Li and S.Wisdom and J.Hershey},
journal={IEEE Spoken Language Technology Workshop 2021},
year={2021}
}
```
14 changes: 14 additions & 0 deletions egs/libri_css/s5_css/cmd.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# you can change cmd.sh depending on what type of queue you are using.
# If you have no queueing system and want to run on a local machine, you
# can change all instances of 'queue.pl' to run.pl (but be careful and run
# commands one by one: most recipes will exhaust the memory on your
# machine). queue.pl works with GridEngine (qsub). slurm.pl works
# with slurm. Different queues are configured differently, with different
# queue names and different ways of specifying things like memory;
# to account for these differences you can create and edit the file
# conf/queue.conf to match your queue's configuration. Search for
# conf/queue.conf in http://kaldi-asr.org/doc/queue.html for more information,
# or search for the string 'default_config' in utils/queue.pl or utils/slurm.pl.

export train_cmd="retry.pl queue.pl --mem 2G"
export decode_cmd="queue.pl --mem 4G"
2 changes: 2 additions & 0 deletions egs/libri_css/s5_css/conf/mfcc.conf
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
--use-energy=false
--sample-frequency=16000
10 changes: 10 additions & 0 deletions egs/libri_css/s5_css/conf/mfcc_hires.conf
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# config for high-resolution MFCC features, intended for neural network training.
# Note: we keep all cepstra, so it has the same info as filterbank features,
# but MFCC is more easily compressible (because less correlated) which is why
# we prefer this method.
--use-energy=false # use average of log energy, not energy.
--sample-frequency=16000
--num-mel-bins=40
--num-ceps=40
--low-freq=40
--high-freq=-400
1 change: 1 addition & 0 deletions egs/libri_css/s5_css/conf/online_cmvn.conf
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# configuration file for apply-cmvn-online, used in the script ../local/run_online_decoding.sh
1 change: 1 addition & 0 deletions egs/libri_css/s5_css/diarization
1 change: 1 addition & 0 deletions egs/libri_css/s5_css/local
9 changes: 9 additions & 0 deletions egs/libri_css/s5_css/path.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Environment setup for this recipe directory: locates Kaldi and extends
# PATH / PYTHONPATH with the tools the recipe scripts call.

# Kaldi root is three directories above the current working directory.
export KALDI_ROOT=`pwd`/../../..
# Load optional tool-specific environment settings if the file exists.
[ -f $KALDI_ROOT/tools/env.sh ] && . $KALDI_ROOT/tools/env.sh
# Put the recipe's utils/, OpenFst and SCTK binaries, and the current
# directory ahead of the existing PATH.
export PATH=$PWD/utils/:$KALDI_ROOT/tools/openfst/bin:$KALDI_ROOT/tools/sctk/bin:$PWD:$PATH
# Make the local dscore directory available both as executables and as
# importable Python modules.
export PATH=$PWD/dscore:$PATH
export PYTHONPATH="${PYTHONPATH}:$PWD/dscore"
# common_path.sh is required; report on stderr and exit with status 1 if missing.
[ ! -f $KALDI_ROOT/tools/config/common_path.sh ] && echo >&2 "The standard file $KALDI_ROOT/tools/config/common_path.sh is not present -> Exit!" && exit 1
. $KALDI_ROOT/tools/config/common_path.sh
# Use the C locale for all locale categories.
export LC_ALL=C

1 change: 1 addition & 0 deletions egs/libri_css/s5_css/rnnlm
Loading

0 comments on commit 1670662

Please sign in to comment.