forked from kaldi-asr/kaldi
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
563 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
144 changes: 144 additions & 0 deletions
144
egs/callhome_diarization/v1/diarization/make_rttm_ol.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,144 @@ | ||
#!/usr/bin/env python | ||
|
||
# Copyright 2020 Desh Raj | ||
# Apache 2.0. | ||
|
||
"""This script converts a segments and labels file to a NIST RTTM | ||
file. It can handle overlapping segmentation. | ||
The segments file format is: | ||
<segment-id> <recording-id> <start-time> <end-time> | ||
The labels file format is: | ||
<segment-id> <speaker-id> | ||
The output RTTM format is: | ||
<type> <file> <chnl> <tbeg> \ | ||
<tdur> <ortho> <stype> <name> <conf> <slat> | ||
where: | ||
<type> = "SPEAKER" | ||
<file> = <recording-id> | ||
<chnl> = "0" | ||
<tbeg> = start time of segment | ||
<tdur> = duration of segment | ||
<ortho> = "<NA>" | ||
<stype> = "<NA>" | ||
<name> = <speaker-id> | ||
<conf> = "<NA>" | ||
<slat> = "<NA>" | ||
""" | ||
|
||
import argparse | ||
import sys | ||
|
||
class Segment: | ||
"""Stores all information about a segment""" | ||
def __init__(self, reco_id, start_time, end_time, labels): | ||
self.reco_id = reco_id | ||
self.start_time = start_time | ||
self.end_time = end_time | ||
self.dur = end_time - start_time | ||
self.labels = labels | ||
|
||
def get_args(): | ||
parser = argparse.ArgumentParser( | ||
description="""This script converts a segments and labels file | ||
to a NIST RTTM file. It handles overlapping segments (e.g. the | ||
output of a sliding-window diarization system).""") | ||
|
||
parser.add_argument("segments", type=str, | ||
help="Input segments file") | ||
parser.add_argument("labels", type=str, | ||
help="Input labels file") | ||
parser.add_argument("rttm_file", type=str, | ||
help="Output RTTM file") | ||
parser.add_argument("--rttm-channel", type=int, default=0, | ||
help="The value passed into the RTTM channel field. \ | ||
Only affects the format of the RTTM file.") | ||
|
||
args = parser.parse_args() | ||
return args | ||
|
||
def main(): | ||
args = get_args() | ||
|
||
# File containing speaker labels per segment | ||
seg2label = {} | ||
with open(args.labels, 'r') as labels_file: | ||
for line in labels_file: | ||
seg, label = line.strip().split() | ||
if seg in seg2label: | ||
seg2label[seg].append(label) | ||
else: | ||
seg2label[seg] = [label] | ||
|
||
# Segments file | ||
reco2segs = {} | ||
with open(args.segments, 'r') as segments_file: | ||
for line in segments_file: | ||
seg, reco, start, end = line.strip().split() | ||
try: | ||
if reco in reco2segs: | ||
reco2segs[reco].append(Segment(reco, float(start), float(end), seg2label[seg])) | ||
else: | ||
reco2segs[reco] = [Segment(reco, float(start), float(end), seg2label[seg])] | ||
except KeyError: | ||
raise RuntimeError("Missing label for segment {0}".format(seg)) | ||
|
||
# At this point the subsegments are overlapping, since we got them from a | ||
# sliding window diarization method. We make them contiguous here | ||
reco2contiguous = {} | ||
for reco in sorted(reco2segs): | ||
segs = sorted(reco2segs[reco], key=lambda x: x.start_time) | ||
new_segs = [] | ||
for i, seg in enumerate(segs): | ||
# If it is last segment in recording or last contiguous segment, add it to new_segs | ||
if (i == len(segs)-1 or seg.end_time <= segs[i+1].start_time): | ||
new_segs.append(Segment(reco, seg.start_time, seg.end_time, seg.labels)) | ||
# Otherwise split overlapping interval between current and next segment | ||
else: | ||
avg = (segs[i+1].start_time + seg.end_time) / 2 | ||
new_segs.append(Segment(reco, seg.start_time, avg, seg.labels)) | ||
segs[i+1].start_time = avg | ||
reco2contiguous[reco] = new_segs | ||
|
||
# Merge contiguous segments of the same label | ||
reco2merged = {} | ||
for reco in reco2contiguous: | ||
segs = reco2contiguous[reco] | ||
new_segs = [] | ||
running_labels = {} # {label: (start_time, end_time)} | ||
for i, seg in enumerate(segs): | ||
# If running labels are not present in current segment, add those segments | ||
# to new_segs list and delete those entries | ||
for label in list(running_labels): | ||
if label not in seg.labels: | ||
new_segs.append(Segment(reco, running_labels[label][0], running_labels[label][1], label)) | ||
del running_labels[label] | ||
# Now add/update labels in running_labels based on current segment | ||
for label in seg.labels: | ||
if label in running_labels: | ||
# If already present, just update end time | ||
start_time = running_labels[label][0] | ||
running_labels[label] = (start_time, seg.end_time) | ||
else: | ||
# Otherwise make a new entry | ||
running_labels[label] = (seg.start_time, seg.end_time) | ||
# If it is the last segment in utterance or last contiguous segment, add it to new_segs | ||
# and delete from running_labels | ||
if (i == len(segs)-1 or seg.end_time < segs[i+1].start_time): | ||
# Case when it is last segment or if next segment is after some gap | ||
for label in list(running_labels): | ||
new_segs.append(Segment(reco, running_labels[label][0], running_labels[label][1], label)) | ||
del running_labels[label] | ||
reco2merged[reco] = new_segs | ||
|
||
with open(args.rttm_file, 'w') as rttm_writer: | ||
for reco in reco2merged: | ||
segs = reco2merged[reco] | ||
for seg in segs: | ||
for label in seg.labels: | ||
rttm_writer.write("SPEAKER {0} {1} {2:7.3f} {3:7.3f} <NA> <NA> {4} <NA> <NA>\n".format( | ||
reco, args.rttm_channel, seg.start_time, seg.dur, label)) | ||
|
||
if __name__ == '__main__': | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
#!/bin/bash | ||
|
||
# Copyright 2016 David Snyder | ||
# 2017-2018 Matthew Maciejewski | ||
# 2020 Maxim Korenevsky (STC-innovations Ltd) | ||
# Apache 2.0. | ||
|
||
# This script performs spectral clustering using scored | ||
# pairs of subsegments and produces a rttm file with speaker | ||
# labels derived from the clusters. | ||
|
||
# Begin configuration section. | ||
cmd="run.pl" | ||
stage=0 | ||
nj=10 | ||
cleanup=true | ||
rttm_channel=0 | ||
reco2num_spk= | ||
overlap_rttm= # Path to an RTTM output of an external overlap detector | ||
rttm_affix= | ||
|
||
# End configuration section. | ||
|
||
echo "$0 $@" # Print the command line for logging | ||
|
||
if [ -f path.sh ]; then . ./path.sh; fi | ||
. parse_options.sh || exit 1; | ||
|
||
|
||
if [ $# != 2 ]; then | ||
echo "Usage: $0 <src-dir> <dir>" | ||
echo " e.g.: $0 exp/ivectors_callhome exp/ivectors_callhome/results" | ||
echo "main options (for others, see top of script file)" | ||
echo " --config <config-file> # config containing options" | ||
echo " --cmd (utils/run.pl|utils/queue.pl <queue opts>) # how to run jobs." | ||
echo " --nj <n|10> # Number of jobs (also see num-processes and num-threads)" | ||
echo " --stage <stage|0> # To control partial reruns" | ||
echo " --rttm-channel <rttm-channel|0> # The value passed into the RTTM channel field. Only affects" | ||
echo " # the format of the RTTM file." | ||
echo " --reco2num-spk <reco2num-spk-file> # File containing mapping of recording ID" | ||
echo " # to number of speakers. Used instead of threshold" | ||
echo " # as stopping criterion if supplied." | ||
echo " --overlap-rttm <overlap-rttm-file> # File containing overlap segments" | ||
echo " --cleanup <bool|false> # If true, remove temporary files" | ||
exit 1; | ||
fi | ||
|
||
srcdir=$1 | ||
dir=$2 | ||
|
||
reco2num_spk_opts= | ||
if [ ! $reco2num_spk == "" ]; then | ||
reco2num_spk_opts="--reco2num-spk $reco2num_spk" | ||
fi | ||
|
||
mkdir -p $dir/tmp | ||
|
||
for f in $srcdir/scores.scp $srcdir/spk2utt $srcdir/utt2spk $srcdir/segments ; do | ||
[ ! -f $f ] && echo "No such file $f" && exit 1; | ||
done | ||
|
||
# We use a different Python version in which the local | ||
# scikit-learn is installed. | ||
miniconda_dir=$HOME/miniconda3/ | ||
if [ ! -d $miniconda_dir ]; then | ||
echo "$miniconda_dir does not exist. Please run '$KALDI_ROOT/tools/extras/install_miniconda.sh'." | ||
exit 1 | ||
fi | ||
|
||
overlap_rttm_opt= | ||
if ! [ -z "$overlap_rttm" ]; then | ||
overlap_rttm_opt="--overlap_rttm $overlap_rttm" | ||
sc_bin="spec_clust_overlap.py" | ||
rttm_bin="make_rttm_ol.py" | ||
# Install a modified version of scikit-learn using: | ||
echo "The overlap-aware spectral clustering requires installing a modified version\n" | ||
echo "of scitkit-learn. You can download it using:\n" | ||
echo "$miniconda_dir/bin/python -m pip install git+https://github.com/desh2608/scikit-learn.git@overlap \n" | ||
echo "if the process fails while clustering." | ||
else | ||
sc_bin="spec_clust.py" | ||
rttm_bin="make_rttm.py" | ||
fi | ||
|
||
cp $srcdir/spk2utt $dir/tmp/ | ||
cp $srcdir/utt2spk $dir/tmp/ | ||
cp $srcdir/segments $dir/tmp/ | ||
utils/fix_data_dir.sh $dir/tmp > /dev/null | ||
|
||
if [ ! -z $reco2num_spk ]; then | ||
reco2num_spk="ark,t:$reco2num_spk" | ||
fi | ||
|
||
sdata=$dir/tmp/split$nj; | ||
utils/split_data.sh $dir/tmp $nj || exit 1; | ||
|
||
# Set various variables. | ||
mkdir -p $dir/log | ||
|
||
feats="utils/filter_scp.pl $sdata/JOB/spk2utt $srcdir/scores.scp |" | ||
|
||
reco2num_spk_opt= | ||
if [ ! $reco2num_spk == "" ]; then | ||
reco2num_spk_opt="--reco2num_spk $reco2num_spk" | ||
fi | ||
|
||
if [ $stage -le 0 ]; then | ||
echo "$0: clustering scores" | ||
for j in `seq $nj`; do | ||
utils/filter_scp.pl $sdata/$j/spk2utt $srcdir/scores.scp > $dir/scores.$j.scp | ||
done | ||
$cmd JOB=1:$nj $dir/log/spectral_cluster.JOB.log \ | ||
$miniconda_dir/bin/python diarization/$sc_bin $reco2num_spk_opt $overlap_rttm_opt \ | ||
scp:$dir/scores.JOB.scp ark,t:$sdata/JOB/spk2utt ark,t:$dir/labels.JOB || exit 1; | ||
fi | ||
|
||
if [ $stage -le 1 ]; then | ||
echo "$0: combining labels" | ||
for j in $(seq $nj); do cat $dir/labels.$j; done > $dir/labels || exit 1; | ||
fi | ||
|
||
if [ $stage -le 2 ]; then | ||
echo "$0: computing RTTM" | ||
diarization/$rttm_bin --rttm-channel $rttm_channel $srcdir/segments $dir/labels $dir/rttm${rttm_affix} || exit 1; | ||
fi | ||
|
||
if $cleanup ; then | ||
rm -r $dir/tmp || exit 1; | ||
fi |
Oops, something went wrong.