Skip to content

Commit

Permalink
added overlap-aware sc
Browse files Browse the repository at this point in the history
  • Loading branch information
desh2608 committed Nov 5, 2020
1 parent fd97c98 commit 2e07a9e
Show file tree
Hide file tree
Showing 4 changed files with 563 additions and 0 deletions.
21 changes: 21 additions & 0 deletions egs/ami/s5c/run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@
# 1. Diarization using x-vector and clustering (AHC, VBx, spectral)
# 2. Training an overlap detector (using annotations) and corresponding
# inference on full recordings.
# 3. Overlap-aware spectral clustering

# The overlap-aware spectral clustering method is based on the paper:
# D.Raj, Z.Huang, S.Khudanpur, "Multi-class spectral clustering with
# overlaps for speaker diarization", IEEE SLT 2021.

# We do not provide training script for an x-vector extractor. You
# can download a pretrained extractor from:
Expand Down Expand Up @@ -163,3 +168,19 @@ if [ $stage -le 9 ]; then
done
fi

# Stage 10: overlap-aware spectral clustering. It consumes the RTTM written by
# the overlap detector in the previous stage.
if [ $stage -le 10 ]; then
  for datadir in $test_sets; do
    ref_rttm=data/$datadir/rttm.annotation
    sc_ol_dir=exp/${datadir}_diarization_spectral_ol

    # Use as many jobs as there are lines in wav.scp.
    nj=$(cat data/$datadir/wav.scp | wc -l)
    local/diarize_spectral_ol.sh \
      --nj $nj --cmd "$train_cmd" --stage $diarizer_stage \
      $model_dir \
      data/$datadir \
      exp/overlap_$overlap_affix/$datadir/rttm_overlap \
      $sc_ol_dir

    # Score the hypothesis RTTM against the reference with md-eval.pl
    md-eval.pl -r $ref_rttm -s $sc_ol_dir/rttm
  done
fi

144 changes: 144 additions & 0 deletions egs/callhome_diarization/v1/diarization/make_rttm_ol.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
#!/usr/bin/env python

# Copyright 2020 Desh Raj
# Apache 2.0.

"""This script converts a segments and labels file to a NIST RTTM
file. It can handle overlapping segmentation.
The segments file format is:
<segment-id> <recording-id> <start-time> <end-time>
The labels file format is:
<segment-id> <speaker-id>
The output RTTM format is:
<type> <file> <chnl> <tbeg> \
<tdur> <ortho> <stype> <name> <conf> <slat>
where:
<type> = "SPEAKER"
<file> = <recording-id>
<chnl> = "0"
<tbeg> = start time of segment
<tdur> = duration of segment
<ortho> = "<NA>"
<stype> = "<NA>"
<name> = <speaker-id>
<conf> = "<NA>"
<slat> = "<NA>"
"""

import argparse
import sys

class Segment(object):
    """Stores all information about a segment.

    Attributes:
        reco_id: ID of the recording the segment belongs to.
        start_time: start time of the segment.
        end_time: end time of the segment.
        labels: speaker label(s) attached to the segment.
        dur: duration of the segment (read-only, see below).
    """

    def __init__(self, reco_id, start_time, end_time, labels):
        self.reco_id = reco_id
        self.start_time = start_time
        self.end_time = end_time
        self.labels = labels

    @property
    def dur(self):
        """Return end_time - start_time.

        Computed on access rather than cached in __init__, so that it stays
        correct when start_time or end_time is modified after construction.
        """
        return self.end_time - self.start_time

def get_args():
    """Build the command-line parser and return the parsed arguments.

    Positional arguments are the segments file, the labels file and the
    output RTTM file; --rttm-channel is an optional integer (default 0).
    """
    description = (
        "This script converts a segments and labels file\n"
        "to a NIST RTTM file. It handles overlapping segments (e.g. the\n"
        "output of a sliding-window diarization system)."
    )
    channel_help = ("The value passed into the RTTM channel field. "
                    "Only affects the format of the RTTM file.")
    positionals = (
        ("segments", "Input segments file"),
        ("labels", "Input labels file"),
        ("rttm_file", "Output RTTM file"),
    )

    parser = argparse.ArgumentParser(description=description)
    for name, help_text in positionals:
        parser.add_argument(name, type=str, help=help_text)
    parser.add_argument("--rttm-channel", type=int, default=0,
                        help=channel_help)
    return parser.parse_args()

def _read_labels(labels_path):
    """Return {segment-id: [speaker labels]} read from the labels file."""
    seg2label = {}
    with open(labels_path, 'r') as labels_file:
        for line in labels_file:
            fields = line.split()
            if not fields:
                # Tolerate blank lines instead of failing on the unpacking.
                continue
            seg, label = fields
            # A segment may appear on several lines, once per speaker.
            seg2label.setdefault(seg, []).append(label)
    return seg2label


def _read_segments(segments_path, seg2label):
    """Return {recording-id: [Segment]} read from the segments file.

    Raises RuntimeError if a segment has no entry in seg2label.
    """
    reco2segs = {}
    with open(segments_path, 'r') as segments_file:
        for line in segments_file:
            fields = line.split()
            if not fields:
                continue
            seg, reco, start, end = fields
            if seg not in seg2label:
                raise RuntimeError("Missing label for segment {0}".format(seg))
            reco2segs.setdefault(reco, []).append(
                Segment(reco, float(start), float(end), seg2label[seg]))
    return reco2segs


def _make_contiguous(reco, segs):
    """Return new segments, sorted by start time, with overlaps removed.

    When a segment overlaps the next one, the boundary between the two is
    placed at the midpoint of the overlapping interval. The input Segment
    objects are not modified.
    """
    ordered = sorted(segs, key=lambda x: x.start_time)
    contiguous = []
    carried_start = None  # start time imposed by the previous split, if any
    for i, seg in enumerate(ordered):
        start = seg.start_time if carried_start is None else carried_start
        carried_start = None
        is_last = (i == len(ordered) - 1)
        if is_last or seg.end_time <= ordered[i + 1].start_time:
            end = seg.end_time
        else:
            # Split the overlapping interval between current and next segment.
            end = (ordered[i + 1].start_time + seg.end_time) / 2.0
            carried_start = end
        contiguous.append(Segment(reco, start, end, seg.labels))
    return contiguous


def _merge_same_label(reco, segs):
    """Merge runs of contiguous segments that share a speaker label.

    Every returned Segment carries exactly one label, stored as a one-element
    list so that iterating over `labels` yields the label itself and not its
    individual characters.
    """
    merged = []
    running = {}  # {label: (start_time, end_time)} of the runs still open

    def close(label):
        start, end = running.pop(label)
        merged.append(Segment(reco, start, end, [label]))

    for i, seg in enumerate(segs):
        # Close the runs whose label is absent from the current segment.
        for label in list(running):
            if label not in seg.labels:
                close(label)
        # Extend the runs present in the current segment, or open new ones.
        for label in seg.labels:
            start = running[label][0] if label in running else seg.start_time
            running[label] = (start, seg.end_time)
        # Close everything at the end of the recording or before a gap.
        if i == len(segs) - 1 or seg.end_time < segs[i + 1].start_time:
            for label in list(running):
                close(label)
    return merged


def _write_rttm(rttm_path, reco2merged, rttm_channel):
    """Write one SPEAKER line per (segment, label) pair to rttm_path."""
    with open(rttm_path, 'w') as rttm_writer:
        # Sorted so the output order does not depend on dict ordering.
        for reco in sorted(reco2merged):
            for seg in reco2merged[reco]:
                for label in seg.labels:
                    rttm_writer.write("SPEAKER {0} {1} {2:7.3f} {3:7.3f} <NA> <NA> {4} <NA> <NA>\n".format(
                        reco, rttm_channel, seg.start_time, seg.dur, label))


def main():
    """Convert a segments file and a labels file into an RTTM file.

    Steps:
      1. read the speaker label(s) of every segment;
      2. read the segments and group them by recording;
      3. make the (possibly overlapping) segments of each recording
         contiguous;
      4. merge contiguous segments that share a label;
      5. write the result in RTTM format.

    Raises RuntimeError if a segment of the segments file has no label.
    """
    args = get_args()

    seg2label = _read_labels(args.labels)
    reco2segs = _read_segments(args.segments, seg2label)

    # At this point the subsegments are overlapping, since we got them from a
    # sliding window diarization method. We make them contiguous, then merge
    # the contiguous segments of the same label.
    reco2merged = {}
    for reco in sorted(reco2segs):
        contiguous = _make_contiguous(reco, reco2segs[reco])
        reco2merged[reco] = _merge_same_label(reco, contiguous)

    _write_rttm(args.rttm_file, reco2merged, args.rttm_channel)

# Run the conversion only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
129 changes: 129 additions & 0 deletions egs/callhome_diarization/v1/diarization/scluster_ol.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
#!/bin/bash

# Copyright 2016 David Snyder
# 2017-2018 Matthew Maciejewski
# 2020 Maxim Korenevsky (STC-innovations Ltd)
# Apache 2.0.

# This script performs spectral clustering using scored
# pairs of subsegments and produces a rttm file with speaker
# labels derived from the clusters.

# Begin configuration section.
# Every variable below can be overridden from the command line as
# --<name> <value> (see the usage message); parse_options.sh does the parsing.
# Command used to launch the parallel clustering jobs.
cmd="run.pl"
# Stages numbered below this value are skipped.
stage=0
# Number of jobs the data is split into for clustering.
nj=10
# If true, $dir/tmp is removed at the end of the script.
cleanup=true
# Forwarded to the RTTM-writing script as --rttm-channel.
rttm_channel=0
# Optional file forwarded to the clustering script as --reco2num_spk.
reco2num_spk=
overlap_rttm= # Path to an RTTM output of an external overlap detector
# Suffix appended to the name of the output RTTM file ($dir/rttm<affix>).
rttm_affix=

# End configuration section.

echo "$0 $@" # Print the command line for logging

# path.sh is sourced only if it exists in the current directory.
if [ -f path.sh ]; then . ./path.sh; fi
. parse_options.sh || exit 1;


# Exactly two positional arguments are required; otherwise print the usage
# message and exit. The defaults shown must match the configuration section.
if [ $# != 2 ]; then
  echo "Usage: $0 <src-dir> <dir>"
  echo " e.g.: $0 exp/ivectors_callhome exp/ivectors_callhome/results"
  echo "main options (for others, see top of script file)"
  echo " --config <config-file> # config containing options"
  echo " --cmd (utils/run.pl|utils/queue.pl <queue opts>) # how to run jobs."
  echo " --nj <n|10> # Number of jobs (also see num-processes and num-threads)"
  echo " --stage <stage|0> # To control partial reruns"
  echo " --rttm-channel <rttm-channel|0> # The value passed into the RTTM channel field. Only affects"
  echo " # the format of the RTTM file."
  echo " --reco2num-spk <reco2num-spk-file> # File containing mapping of recording ID"
  echo " # to number of speakers. Used instead of threshold"
  echo " # as stopping criterion if supplied."
  echo " --overlap-rttm <overlap-rttm-file> # File containing overlap segments"
  echo " --rttm-affix <affix> # Suffix appended to the name of the output RTTM file"
  echo " --cleanup <bool|true> # If true, remove temporary files"
  exit 1;
fi

# Positional arguments: the directory holding the scores and the data files,
# and the output directory.
srcdir=$1
dir=$2

mkdir -p "$dir/tmp"

# All of these inputs must exist before any work is done.
for f in $srcdir/scores.scp $srcdir/spk2utt $srcdir/utt2spk $srcdir/segments ; do
  [ ! -f "$f" ] && echo "No such file $f" && exit 1;
done

# We use a different Python version in which the local
# scikit-learn is installed.
miniconda_dir=$HOME/miniconda3/
if [ ! -d "$miniconda_dir" ]; then
  echo "$miniconda_dir does not exist. Please run '$KALDI_ROOT/tools/extras/install_miniconda.sh'."
  exit 1
fi

# Choose the clustering and RTTM-writing scripts. The overlap-aware versions
# are used only when an overlap RTTM was supplied with --overlap-rttm.
overlap_rttm_opt=
if [ -n "$overlap_rttm" ]; then
  overlap_rttm_opt="--overlap_rttm $overlap_rttm"
  sc_bin="spec_clust_overlap.py"
  rttm_bin="make_rttm_ol.py"
  # Plain echo does not interpret "\n" in bash, so each line is printed
  # with its own echo and no escape sequences.
  echo "The overlap-aware spectral clustering requires installing a modified version"
  echo "of scikit-learn. You can download it using:"
  echo "$miniconda_dir/bin/python -m pip install git+https://github.com/desh2608/scikit-learn.git@overlap"
  echo "if the process fails while clustering."
else
  sc_bin="spec_clust.py"
  rttm_bin="make_rttm.py"
fi

# Work on a private copy of the data files so that fixing and splitting
# them does not touch the source directory.
for f in spk2utt utt2spk segments; do
  cp $srcdir/$f $dir/tmp/ || exit 1;
done
utils/fix_data_dir.sh $dir/tmp > /dev/null

# If a reco2num_spk file was given, turn it into a text-archive specifier
# and build the corresponding option for the clustering script.
reco2num_spk_opt=
if [ -n "$reco2num_spk" ]; then
  reco2num_spk="ark,t:$reco2num_spk"
  reco2num_spk_opt="--reco2num_spk $reco2num_spk"
fi

sdata=$dir/tmp/split$nj;
utils/split_data.sh $dir/tmp $nj || exit 1;

# Set various variables.
mkdir -p $dir/log

# Stage 0: filter the scores for each job, then run the clustering jobs.
if [ $stage -le 0 ]; then
  echo "$0: clustering scores"
  for j in $(seq $nj); do
    # Stop here if filtering fails rather than clustering a broken scp file.
    utils/filter_scp.pl $sdata/$j/spk2utt $srcdir/scores.scp > $dir/scores.$j.scp || exit 1;
  done
  # The two *_opt variables are left unquoted on purpose: each expands to
  # either nothing or an option followed by its value.
  $cmd JOB=1:$nj $dir/log/spectral_cluster.JOB.log \
    $miniconda_dir/bin/python diarization/$sc_bin $reco2num_spk_opt $overlap_rttm_opt \
    scp:$dir/scores.JOB.scp ark,t:$sdata/JOB/spk2utt ark,t:$dir/labels.JOB || exit 1;
fi

# Stage 1: concatenate the per-job label files.
if [ $stage -le 1 ]; then
  echo "$0: combining labels"
  for j in $(seq $nj); do cat $dir/labels.$j; done > $dir/labels || exit 1;
fi

# Stage 2: convert the segments and labels into an RTTM file.
if [ $stage -le 2 ]; then
  echo "$0: computing RTTM"
  diarization/$rttm_bin --rttm-channel $rttm_channel $srcdir/segments $dir/labels $dir/rttm${rttm_affix} || exit 1;
fi

if $cleanup ; then
  rm -r $dir/tmp || exit 1;
fi
Loading

0 comments on commit 2e07a9e

Please sign in to comment.