diff --git a/users/gruev/lexicon/bpe_lexicon.py b/users/gruev/lexicon/bpe_lexicon.py
new file mode 100644
index 000000000..14d7fdd77
--- /dev/null
+++ b/users/gruev/lexicon/bpe_lexicon.py
@@ -0,0 +1,333 @@
+__all__ = ["CreateBPELexiconJob", "ApplyBPEToTextJob"]
+
+import os
+import sys
+import shutil
+import tempfile
+import subprocess as sp
+
+import xml.etree.ElementTree as ET
+from typing import Optional
+
+import i6_core.util as util
+from i6_core.lib.lexicon import Lexicon, Lemma
+from sisyphus import tk, gs, Task, Job
+
+
+class CreateBPELexiconJob(Job):
+ """
+    Create a Bliss lexicon from BPE transcriptions: the BPE tokens become the phoneme
+    inventory and each word lemma is mapped to its BPE segmentation as pronunciation.
+ """
+
+ def __init__(
+ self,
+ base_lexicon_path,
+ bpe_codes,
+ bpe_vocab,
+ subword_nmt_repo=None,
+ unk_label="UNK",
+ ):
+ """
+        :param Path base_lexicon_path: path to the base Bliss lexicon providing the lemmata
+        :param Path bpe_codes: path to the BPE codes, e.g. ReturnnTrainBpeJob.out_bpe_codes
+        :param Path bpe_vocab: path to the BPE vocab, e.g. ReturnnTrainBpeJob.out_bpe_vocab
+        :param Path|str|None subword_nmt_repo: path to the subword-nmt repository, defaults to gs.SUBWORD_NMT_PATH
+        :param str unk_label: label used for unknown tokens
+ """
+ self.base_lexicon_path = base_lexicon_path
+ self.bpe_codes = bpe_codes
+ self.bpe_vocab = bpe_vocab
+ self.subword_nmt_repo = (
+ subword_nmt_repo if subword_nmt_repo is not None else gs.SUBWORD_NMT_PATH
+ )
+ self.unk_label = unk_label
+
+ self.out_lexicon = self.output_path("lexicon.xml.gz", cached=True)
+
+ def tasks(self):
+ yield Task("run", resume="run", mini_task=True)
+
+ def run(self):
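+        # Rough outline: collect all orth/synt/eval tokens from the base lexicon, write
+        # them to a word list, derive a pseudo count vocabulary from the BPE vocab, run
+        # subword-nmt's apply_bpe.py on the word list, and finally add one lemma per
+        # word whose pronunciation is its BPE segmentation.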
+ lexicon = Lexicon()
+
+ lm_tokens = set()
+
+ base_lexicon = Lexicon()
+ base_lexicon.load(self.base_lexicon_path)
+ for l in base_lexicon.lemmata:
+ if l.special is None:
+ for orth in l.orth:
+ lm_tokens.add(orth)
+ for token in l.synt or []: # l.synt can be None
+ lm_tokens.add(token)
+ for eval in l.eval:
+ for t in eval:
+ lm_tokens.add(t)
+
+ lm_tokens = list(lm_tokens)
+
+ with util.uopen("words", "wt") as f:
+ for t in lm_tokens:
+ f.write(f"{t}\n")
+
+ vocab = set()
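+        # apply_bpe.py expects its --vocabulary file in "<subword> <count>" format and,
+        # without a --vocabulary-threshold, only uses it for membership checks, so a
+        # dummy count of -1 is sufficient here.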
+ with util.uopen(self.bpe_vocab.get_path(), "rt") as f, util.uopen("fake_count_vocab.txt", "wt") as vocab_file:
+ for line in f:
+                if "{" in line or "<s>" in line or "</s>" in line or "}" in line:
+ continue
+ symbol = line.split(":")[0][1:-1]
+ if symbol != self.unk_label:
+ vocab_file.write(symbol + " -1\n")
+ symbol = symbol.replace(".", "_")
+ vocab.add(symbol)
+                    lexicon.add_phoneme(symbol)
+
+ apply_binary = os.path.join(
+ tk.uncached_path(self.subword_nmt_repo), "apply_bpe.py"
+ )
+ args = [
+ sys.executable,
+ apply_binary,
+ "--input",
+ "words",
+ "--codes",
+ self.bpe_codes.get_path(),
+ "--vocabulary",
+ "fake_count_vocab.txt",
+ "--output",
+ "bpes",
+ ]
+ sp.run(args, check=True)
+
+ with util.uopen("bpes", "rt") as f:
+ bpe_tokens = [l.strip() for l in f]
+
+ w2b = {w: b for w, b in zip(lm_tokens, bpe_tokens)}
+
+ for w, b in w2b.items():
+ b = " ".join(
+ [token if token in vocab else self.unk_label for token in b.split()]
+ )
+ lexicon.add_lemma(Lemma([w], [b.replace(".", "_")]))
+
+ elem = lexicon.to_xml()
+ tree = ET.ElementTree(elem)
+ util.write_xml(self.out_lexicon.get_path(), tree)
+
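+# A minimal usage sketch (hypothetical paths; assumes a Sisyphus setup where
+# gs.SUBWORD_NMT_PATH points to a subword-nmt checkout):
+#
+#   bpe_lexicon_job = CreateBPELexiconJob(
+#       base_lexicon_path=tk.Path("/path/to/lexicon.xml.gz"),
+#       bpe_codes=tk.Path("/path/to/bpe.codes"),
+#       bpe_vocab=tk.Path("/path/to/bpe.vocab"),
+#   )
+#   tk.register_output("lexicon/bpe_lexicon.xml.gz", bpe_lexicon_job.out_lexicon)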
+
+# class CreateBPELexiconJob(Job):
+# """
+# Create a Bliss lexicon from BPE transcriptions.
+# """
+#
+# def __init__(
+# self,
+# base_lexicon_path: tk.Path,
+# bpe_codes: tk.Path,
+# bpe_vocab: tk.Path,
+# subword_nmt_repo: Optional[tk.Path] = None,
+# unk_label: str = "[UNKNOWN]",
+# add_silence: bool = True,
+# add_other_special: bool = False,
+# ):
+# """
+# :param tk.Path base_lexicon_path: path to a Bliss lexicon
+#         :param tk.Path bpe_codes: path to BPE codes produced by e.g. ReturnnTrainBpeJob
+#         :param tk.Path bpe_vocab: path to BPE vocab produced by e.g. ReturnnTrainBpeJob
+# :param tk.Path|None subword_nmt_repo:
+# :param str unk_label:
+# :param bool add_silence: explicitly include a [SILENCE] phoneme and lemma
+# :param bool add_other_special: explicitly include special lemmata from base_lexicon_path
+# """
+#
+# self.base_lexicon_path = base_lexicon_path
+# self.bpe_codes = bpe_codes
+# self.bpe_vocab = bpe_vocab
+# self.subword_nmt_repo = (
+# subword_nmt_repo if subword_nmt_repo is not None else gs.SUBWORD_NMT_PATH
+# )
+# self.unk_label = unk_label
+# self.add_silence = add_silence
+# self.add_other_special = add_other_special
+#
+# self.out_lexicon = self.output_path("lexicon.xml.gz", cached=True)
+#
+# def tasks(self):
+# yield Task("run", resume="run", mini_task=True)
+#
+# def run(self):
+# lm_tokens = set()
+# other_special = []
+#
+# base_lexicon = Lexicon()
+# base_lexicon.load(self.base_lexicon_path)
+#
+# for l in base_lexicon.lemmata:
+# if l.special:
+# if l.special not in ["silence", "unknown"]:
+# other_special.append(l)
+# continue
+# for orth in l.orth or []: # l.orth can be None
+# lm_tokens.add(orth)
+# for token in l.synt or []: # l.synt can be None
+# lm_tokens.add(token)
+# for eval in l.eval or []: # l.eval can be None
+# for t in eval:
+# lm_tokens.add(t)
+#
+#         lm_tokens = [lt for lt in lm_tokens if lt != '']  # catch special tokens such as <s>, </s> or <unk>
+#
+# with util.uopen("words", "wt") as f:
+# for t in lm_tokens:
+# f.write(f"{t}\n")
+#
+# vocab = set()
+# lexicon = Lexicon()
+#
+# lexicon.add_phoneme(self.unk_label, variation="none")
+#
+# if self.add_silence:
+# lexicon.add_phoneme("[SILENCE]", variation="none")
+#
+# with util.uopen(self.bpe_vocab.get_path(), "rt") as bpe_vocab_file:
+# with util.uopen("fake_count_vocab", "wt") as fake_count_file:
+# for line in bpe_vocab_file:
+#                 if "{" in line or "<s>" in line or "</s>" in line or "}" in line:
+# continue
+# symbol = line.split(":")[0][1:-1]
+# if symbol != self.unk_label:
+# fake_count_file.write(symbol + " -1\n")
+# symbol = symbol.replace(".", "_")
+# vocab.add(symbol)
+# lexicon.add_phoneme(symbol)
+#
+# apply_binary = os.path.join(
+# tk.uncached_path(self.subword_nmt_repo), "apply_bpe.py"
+# )
+# args = [
+# sys.executable,
+# apply_binary,
+# "--input",
+# "words",
+# "--codes",
+# self.bpe_codes.get_path(),
+# "--vocabulary",
+# "fake_count_vocab",
+# "--output",
+# "bpes",
+# ]
+# sp.run(args, check=True)
+#
+# with util.uopen("bpes", "rt") as f:
+# bpe_tokens = [l.strip() for l in f]
+#
+# w2b = {w: b for w, b in zip(lm_tokens, bpe_tokens)}
+#
+# lexicon.add_lemma(
+# Lemma(["[UNKNOWN]"], [self.unk_label], None, None, special="unknown")
+# )
+#
+# if self.add_silence:
+# lexicon.add_lemma(
+# Lemma(["[SILENCE]"], ["[SILENCE]"], [], [[]], special="silence")
+# )
+#
+# if self.add_other_special:
+# for l in other_special:
+# lexicon.add_lemma(l)
+#
+# for w, b in w2b.items():
+# b = " ".join(
+# [token if token in vocab else self.unk_label for token in b.split()]
+# )
+# lexicon.add_lemma(Lemma([w], [b.replace(".", "_")]))
+#
+# elem = lexicon.to_xml()
+# tree = ET.ElementTree(elem)
+# util.write_xml(self.out_lexicon.get_path(), tree)
+
+
+class ApplyBPEToTextJob(Job):
+ """
+    Apply BPE codes to a text file
+ """
+
+ __sis_hash_exclude__ = {"gzip_output": False}
+
+ def __init__(
+ self,
+ words_file: tk.Path,
+ bpe_codes: tk.Path,
+ bpe_vocab: tk.Path,
+ subword_nmt_repo: Optional[tk.Path] = None,
+ gzip_output: bool = False,
+ mini_task: bool = True,
+ ):
+ """
+        :param tk.Path words_file: path to a words text file
+        :param tk.Path bpe_codes: path to BPE codes file, use e.g. ReturnnTrainBpeJob.out_bpe_codes
+        :param tk.Path bpe_vocab: path to BPE vocab file used to revert merge operations that produce OOV,
+            use e.g. ReturnnTrainBpeJob.out_bpe_vocab
+        :param tk.Path|None subword_nmt_repo: path to the subword-nmt repository, see also `CloneGitRepositoryJob`
+        :param bool gzip_output: gzip the output text file
+        :param bool mini_task: whether the Job should run locally, e.g. when only a small (<1M lines) text is processed
+ """
+ self.words_file = words_file
+ self.bpe_codes = bpe_codes
+ self.bpe_vocab = bpe_vocab
+ self.subword_nmt_repo = (
+ subword_nmt_repo if subword_nmt_repo is not None else gs.SUBWORD_NMT_PATH
+ )
+ self.gzip_output = gzip_output
+
+ self.out_bpe_text = self.output_path("words_to_bpe.txt.gz" if gzip_output else "words_to_bpe.txt")
+
+ self.mini_task = mini_task
+ self.rqmt = {"cpu": 1, "mem": 2, "time": 2}
+
+ def tasks(self):
+ if self.mini_task:
+ yield Task("run", mini_task=True)
+ else:
+ yield Task("run", rqmt=self.rqmt)
+
+ def run(self):
+ with tempfile.TemporaryDirectory(prefix=gs.TMP_PREFIX) as tmp:
+ words_file = self.words_file.get_path()
+ tmp_outfile = os.path.join(tmp, "out_text.txt")
+
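+            # Build a pseudo count vocabulary: apply_bpe.py expects "<subword> <count>"
+            # lines and (without --vocabulary-threshold) only checks membership, so a
+            # dummy count of -1 is enough to make it revert merges that would produce
+            # out-of-vocabulary subwords.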
+ with util.uopen(self.bpe_vocab.get_path(), "rt") as bpe_vocab_file:
+ with util.uopen("fake_count_vocab.txt", "wt") as fake_count_file:
+ for line in bpe_vocab_file:
+ if "{" in line or "<" in line or "[" in line or "]" in line or ">" in line or "}" in line:
+ continue
+ symbol = line.split(":")[0][1:-1]
+ fake_count_file.write(symbol + " -1\n")
+
+ apply_binary = os.path.join(
+ tk.uncached_path(self.subword_nmt_repo), "apply_bpe.py"
+ )
+ cmd = [
+ sys.executable,
+ apply_binary,
+ "--input",
+ words_file,
+ "--codes",
+ self.bpe_codes.get_path(),
+ "--vocabulary",
+ "fake_count_vocab.txt",
+ "--output",
+ tmp_outfile,
+ ]
+ util.create_executable("apply_bpe.sh", cmd)
+ sp.run(cmd, check=True)
+
+ if self.gzip_output:
+ with util.uopen(tmp_outfile, "rt") as fin, util.uopen(self.out_bpe_text, "wb") as fout:
+ sp.call(["gzip"], stdin=fin, stdout=fout)
+ else:
+ shutil.copy(tmp_outfile, self.out_bpe_text.get_path())
+
+ @classmethod
+ def hash(cls, parsed_args):
+ del parsed_args["mini_task"]
+ return super().hash(parsed_args)
diff --git a/users/gruev/statistics/alignment.py b/users/gruev/statistics/alignment.py
new file mode 100644
index 000000000..738783658
--- /dev/null
+++ b/users/gruev/statistics/alignment.py
@@ -0,0 +1,102 @@
+import subprocess
+import tempfile
+import shutil
+import os
+import json
+
+from recipe.i6_core.util import create_executable
+
+from sisyphus import *
+
+import recipe.i6_private.users.gruev.tools as tools_mod
+tools_dir = os.path.dirname(tools_mod.__file__)
+
+
+class AlignmentStatisticsJob(Job):
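+    """
+    Compute segment statistics (durations, histograms, quantiles) for a label alignment
+    stored in an HDF file by calling the segment_statistics.py tool with RETURNN.
+    """
+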
+ def __init__(
+ self,
+ alignment,
+ num_labels,
+ blank_idx=0,
+ seq_list_filter_file=None,
+ time_rqmt=2,
+ returnn_python_exe=None,
+ returnn_root=None,
+ ):
+
+ self.returnn_python_exe = (
+ returnn_python_exe
+ if returnn_python_exe is not None
+ else gs.RETURNN_PYTHON_EXE
+ )
+ self.returnn_root = (
+ returnn_root if returnn_root is not None else gs.RETURNN_ROOT
+ )
+
+ self.alignment = alignment
+ self.seq_list_filter_file = seq_list_filter_file
+ self.num_labels = num_labels
+ self.blank_idx = blank_idx
+
+ self.out_statistics = self.output_path("statistics.txt")
+ self.out_labels_hist = self.output_path("labels_histogram.pdf")
+ self.out_mean_label_seg_lens = self.output_path("mean_label_seg_lens.json")
+ # self.out_mean_label_seg_lens_var = self.output_var("mean_label_seg_lens_var.json")
+ self.out_mean_label_seg_len = self.output_path("mean_label_seg_len.txt")
+ # self.out_mean_label_seg_len_var = self.output_var("mean_label_seg_len_var.txt")
+ self.out_mean_label_seq_len = self.output_path("mean_label_seq_len.txt")
+ # self.out_mean_label_seq_len_var = self.output_var("mean_label_seq_len_var.txt")
+ self.out_90_quantile_var = self.output_var("quantile_90")
+ self.out_95_quantile_var = self.output_var("quantile_95")
+ self.out_99_quantile_var = self.output_var("quantile_99")
+
+ self.time_rqmt = time_rqmt
+
+ def tasks(self):
+ yield Task("run", rqmt={"cpu": 1, "mem": 4, "time": self.time_rqmt})
+
+ def run(self):
+ command = [
+ self.returnn_python_exe.get_path(),
+ os.path.join(tools_dir, "segment_statistics.py"),
+ self.alignment.get_path(),
+ "--num-labels",
+ str(self.num_labels),
+ "--blank-idx",
+ str(self.blank_idx),
+ "--returnn-root",
+ self.returnn_root.get_path(),
+ ]
+
+ if self.seq_list_filter_file:
+ command += ["--seq-list-filter-file", str(self.seq_list_filter_file)]
+
+ create_executable("rnn.sh", command)
+ subprocess.check_call(["./rnn.sh"])
+
+ # with open("mean_label_seg_lens.json", "r") as f:
+ # mean_label_seg_lens = json.load(f)
+ # mean_label_seg_lens = [
+ # mean_label_seg_lens[str(idx)] for idx in range(self.num_labels)
+ # ]
+ # self.out_mean_label_seg_lens_var.set(mean_label_seg_lens)
+ #
+ # with open("mean_label_seg_len.txt", "r") as f:
+ # self.out_mean_label_seg_len_var.set(float(f.read()))
+ #
+ # with open("mean_label_seq_len.txt", "r") as f:
+ # self.out_mean_label_seq_len_var.set(float(f.read()))
+
+ # Set quantiles
+ with open("quantile_90", "r") as f:
+ self.out_90_quantile_var.set(int(float(f.read())))
+ with open("quantile_95", "r") as f:
+ self.out_95_quantile_var.set(int(float(f.read())))
+ with open("quantile_99", "r") as f:
+ self.out_99_quantile_var.set(int(float(f.read())))
+
+ shutil.move("statistics.txt", self.out_statistics.get_path())
+ shutil.move("labels_histogram.pdf", self.out_labels_hist.get_path())
+ shutil.move("mean_label_seg_lens.json", self.out_mean_label_seg_lens.get_path())
+ shutil.move("mean_label_seg_len.txt", self.out_mean_label_seg_len.get_path())
+ shutil.move("mean_label_seq_len.txt", self.out_mean_label_seq_len.get_path())
\ No newline at end of file
diff --git a/users/gruev/statistics/bpe_statistics.py b/users/gruev/statistics/bpe_statistics.py
new file mode 100755
index 000000000..99fe5a782
--- /dev/null
+++ b/users/gruev/statistics/bpe_statistics.py
@@ -0,0 +1,122 @@
+import subprocess
+import shutil
+import os
+
+from recipe.i6_core.util import create_executable
+
+from sisyphus import *
+
+import recipe.i6_private.users.gruev.tools as tools_mod
+
+tools_dir = os.path.dirname(tools_mod.__file__)
+
+
+class BpeStatisticsJob(Job):
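+    """
+    Compute BPE subword statistics (tokens per word/sequence, symbols per token,
+    token counts, OOV words) for a corpus, given a BPE Bliss lexicon and a
+    transcription text file, by calling the bpe_statistics.py tool.
+    """
+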
+ def __init__(self, bliss_lexicon, transcription, returnn_python_exe=None):
+ """
+        :param bliss_lexicon: Bliss lexicon with BPE subword units as pronunciation sequences
+        :param transcription: a text file with utterance transcriptions
+        :param returnn_python_exe: Python executable used to run the statistics tool, defaults to gs.RETURNN_PYTHON_EXE
+ """
+
+ self.bliss_lexicon = bliss_lexicon
+ self.transcription = transcription
+
+ if returnn_python_exe is not None:
+ self.returnn_python_exe = returnn_python_exe
+ else:
+ self.returnn_python_exe = gs.RETURNN_PYTHON_EXE
+
+ ## Example:
+ # Text: AMAZINGLY COMPLICATED COMPLICATED WANAMAKER
+ # Segmentation: AMA@@ ZINGLY COMP@@ LIC@@ ATED COMP@@ LIC@@ ATED WAN@@ AMA@@ KER
+
+        # Num tokens per sequence: 1 sequence with 11 BPE tokens --> average is 11.
+        # Num tokens per word: 4 words with 2 + 3 + 3 + 3 tokens --> average is 11/4 = 2.75
+        # Num symbols per token: 7 tokens with 3 + 6 + 4 + 3 + 4 + 3 + 3 symbols --> average is 26/7 = 3.71
+        # Token count per vocab: AMA@@ appears two times, COMP@@ and LIC@@ appear one time (per unique word!)
+        # Token count per corpus: AMA@@ appears two times, COMP@@ and LIC@@ appear two times (per sequence!)
+
+ # Summary of all other statistics, register as output
+ self.out_bpe_statistics = self.output_path("bpe_statistics.txt")
+
+ # Average number of BPE tokens for a sequence (utterance)
+ self.out_mean_num_token_per_sequence = self.output_path(
+ "mean_num_token_per_sequence.txt"
+ )
+ self.out_num_token_per_sequence_histogram = self.output_path(
+ "num_token_per_sequence_histogram.pdf"
+ )
+
+ # Average number of BPE tokens for a single word in the corpus
+ self.out_mean_num_token_per_word = self.output_path(
+ "mean_num_token_per_word.txt"
+ )
+ self.out_num_token_per_word_histogram = self.output_path(
+ "num_token_per_word_histogram.pdf"
+ )
+
+ # Average number of symbols that comprise a BPE token
+ self.out_mean_num_symbols_per_token = self.output_path(
+ "mean_num_symbols_per_token.txt"
+ )
+ self.out_num_symbols_per_token_histogram = self.output_path(
+ "num_symbols_per_token_histogram.pdf"
+ )
+
+ # Number of BPE tokens per vocabulary (all words)
+ self.out_mean_token_count_per_vocab = self.output_path(
+ "mean_token_count_per_vocab.txt"
+ )
+ self.out_token_count_per_vocab = self.output_path("token_count_per_vocab.json")
+ self.out_token_count_per_vocab_plot = self.output_path(
+ "token_count_per_vocab_plot.pdf"
+ )
+
+        # Number of BPE tokens per corpus (all sequences)
+ self.out_mean_token_count_per_corpus = self.output_path(
+ "mean_token_count_per_corpus.txt"
+ )
+ self.out_token_count_per_corpus = self.output_path(
+ "token_count_per_corpus.json"
+ )
+ self.out_token_count_per_corpus_plot = self.output_path(
+ "token_count_per_corpus_plot.pdf"
+ )
+
+ # OOV words in corpus
+ self.out_oov_words = self.output_path("oov_words.json")
+
+ def tasks(self):
+ yield Task("run", rqmt={"cpu": 1, "mem": 2, "time": 2})
+
+ def run(self):
+ command = [
+ self.returnn_python_exe.get_path(),
+ os.path.join(tools_dir, "bpe_statistics.py"),
+ self.bliss_lexicon.get_path(),
+ self.transcription.get_path(),
+ ]
+
+ create_executable("rnn.sh", command)
+ subprocess.check_call(["./rnn.sh"])
+
+ # Register
+ shutil.move("bpe_statistics.txt", self.out_bpe_statistics.get_path())
+ shutil.move("oov_words.json", self.out_oov_words.get_path())
+
+ for (stat, fig) in [
+ ("num_token_per_sequence", "histogram"),
+ ("num_token_per_word", "histogram"),
+ ("num_symbols_per_token", "histogram"),
+ ("token_count_per_vocab", "plot"),
+ ("token_count_per_corpus", "plot"),
+ ]:
+ shutil.move(
+ f"mean_{stat}.txt", self.__dict__[f"out_mean_{stat}"].get_path()
+ )
+ shutil.move(
+ f"{stat}_{fig}.pdf", self.__dict__[f"out_{stat}_{fig}"].get_path()
+ )
+
+ if stat in ["token_count_per_vocab", "token_count_per_corpus"]:
+ shutil.move(f"{stat}.json", self.__dict__[f"out_{stat}"].get_path())
diff --git a/users/gruev/statistics/schmitt_alignment.py b/users/gruev/statistics/schmitt_alignment.py
new file mode 100644
index 000000000..aa10caa06
--- /dev/null
+++ b/users/gruev/statistics/schmitt_alignment.py
@@ -0,0 +1,94 @@
+from sisyphus import *
+
+from recipe.i6_core.util import create_executable
+from recipe.i6_core.rasr.config import build_config_from_mapping
+from recipe.i6_core.rasr.command import RasrCommand
+from i6_core.returnn.config import ReturnnConfig
+from i6_core.returnn.forward import ReturnnForwardJob
+
+from sisyphus import Path
+
+import subprocess
+import tempfile
+import shutil
+import os
+import json
+
+import recipe.i6_private.users.gruev.tools as tools_mod
+# import recipe.i6_experiments.users.schmitt.tools as tools_mod
+tools_dir = os.path.dirname(tools_mod.__file__)
+
+class AlignmentStatisticsJob(Job):
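+    """
+    Variant of the alignment statistics job that additionally distinguishes silence
+    segments; wraps the schmitt_segment_statistics.py tool.
+    """
+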
+ def __init__(self, alignment, seq_list_filter_file=None, blank_idx=0, silence_idx=None,
+ time_rqmt=2, returnn_python_exe=None,
+ returnn_root=None):
+ self.returnn_python_exe = (returnn_python_exe if returnn_python_exe is not None else gs.RETURNN_PYTHON_EXE)
+ self.returnn_root = (returnn_root if returnn_root is not None else gs.RETURNN_ROOT)
+
+ self.alignment = alignment
+ self.seq_list_filter_file = seq_list_filter_file
+ self.blank_idx = blank_idx
+ self.silence_idx = silence_idx
+ self.out_statistics = self.output_path("statistics")
+ self.out_sil_hist = self.output_path("sil_histogram.pdf")
+ self.out_non_sil_hist = self.output_path("non_sil_histogram.pdf")
+ self.out_label_dep_stats = self.output_path("label_dep_mean_lens")
+ # self.out_label_dep_vars = self.output_path("label_dep_mean_vars")
+ self.out_label_dep_stats_var = self.output_var("label_dep_mean_lens_var", pickle=True)
+ # self.out_label_dep_vars_var = self.output_var("label_dep_mean_vars_var", pickle=True)
+ self.out_mean_non_sil_len = self.output_path("mean_non_sil_len")
+ self.out_mean_non_sil_len_var = self.output_var("mean_non_sil_len_var")
+ self.out_95_percentile_var = self.output_var("percentile_95")
+ self.out_90_percentile_var = self.output_var("percentile_90")
+ self.out_99_percentile_var = self.output_var("percentile_99")
+
+ self.time_rqmt = time_rqmt
+
+ def tasks(self):
+ yield Task("run", rqmt={"cpu": 1, "mem": 4, "time": self.time_rqmt})
+
+ def run(self):
+ command = [
+ self.returnn_python_exe.get_path(),
+ os.path.join(tools_dir, "schmitt_segment_statistics.py"),
+ self.alignment.get_path(),
+ "--blank-idx", str(self.blank_idx), "--sil-idx", str(self.silence_idx),
+ "--returnn-root", self.returnn_root.get_path()
+ ]
+
+ if self.seq_list_filter_file:
+ command += ["--seq-list-filter-file", str(self.seq_list_filter_file)]
+
+ create_executable("rnn.sh", command)
+ subprocess.check_call(["./rnn.sh"])
+
+ with open("label_dep_mean_lens", "r") as f:
+ label_dep_means = json.load(f)
+ label_dep_means = {int(k): v for k, v in label_dep_means.items()}
+ label_dep_means = [label_dep_means[idx] for idx in range(len(label_dep_means)) if idx > 1]
+
+ # with open("label_dep_vars", "r") as f:
+ # label_dep_vars = json.load(f)
+ # label_dep_vars = {int(k): v for k, v in label_dep_vars.items()}
+ # label_dep_vars = [label_dep_vars[idx] for idx in range(len(label_dep_vars))]
+
+ with open("mean_non_sil_len", "r") as f:
+ self.out_mean_non_sil_len_var.set(float(f.read()))
+
+ # set percentiles
+ with open("percentile_90", "r") as f:
+ self.out_90_percentile_var.set(int(float(f.read())))
+ with open("percentile_95", "r") as f:
+ self.out_95_percentile_var.set(int(float(f.read())))
+ with open("percentile_99", "r") as f:
+ self.out_99_percentile_var.set(int(float(f.read())))
+
+ self.out_label_dep_stats_var.set(label_dep_means)
+ # self.out_label_dep_vars_var.set(label_dep_vars)
+
+ shutil.move("statistics", self.out_statistics.get_path())
+ shutil.move("sil_histogram.pdf", self.out_sil_hist.get_path())
+ shutil.move("non_sil_histogram.pdf", self.out_non_sil_hist.get_path())
+ shutil.move("label_dep_mean_lens", self.out_label_dep_stats.get_path())
+ # shutil.move("label_dep_vars", self.out_label_dep_vars.get_path())
+ shutil.move("mean_non_sil_len", self.out_mean_non_sil_len.get_path())
diff --git a/users/gruev/tools/bpe_statistics.py b/users/gruev/tools/bpe_statistics.py
new file mode 100755
index 000000000..d23352b04
--- /dev/null
+++ b/users/gruev/tools/bpe_statistics.py
@@ -0,0 +1,370 @@
+import sys
+import json
+import gzip
+import argparse
+import numpy as np
+from collections import Counter
+import matplotlib.pyplot as plt
+import xml.etree.ElementTree as ET
+
+
+# kwargs['xlabel']: Optional[str]
+# kwargs['ylabel']: Optional[str]
+# kwargs['percentiles']: Optional[List[int]]
+
+
+def num_bin_heuristic(num_values):
+ """ Try to partition the data into 20-30 bins. """
+ num_values_rounded = np.round(num_values / 5) * 5
+
+ min_remainder, num_bins = 30, 0
+ for n in np.arange(20, 31):
+ if min_remainder > num_values_rounded % n:
+ min_remainder = num_values_rounded % n
+ num_bins = n
+
+ return num_bins
+
+
+def make_histogram(counter, alias, **kwargs):
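+    """
+    Plot a histogram of the values stored in ``counter`` (a collections.Counter mapping
+    value -> count) and save it as ``<alias>.pdf``. Supported kwargs: ``xlabel``,
+    ``ylabel`` and ``percentiles`` (a list of percentiles to mark with vertical lines).
+    """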
+ # Data
+ data = [*counter.elements()]
+
+ max_num_bins = num_bin_heuristic(len(counter))
+ num_bins = min(len(counter) + 1, max_num_bins)
+ custom_binning = len(counter) < max_num_bins
+
+ # TODO: Add more kwargs for histogram settings, etc.
+
+ # Differentiate for smaller and larger number of bins
+ bins = max_num_bins
+ xticks = None
+ if custom_binning:
+ bins = np.arange(1, num_bins + 1) - 0.5
+ xticks = np.arange(1, num_bins)
+
+ # General
+ hist, _, _ = plt.hist(
+ data, bins=bins, alpha=0.75, range=(1, len(counter) + 2), ec="black",
+ )
+ plt.xticks(xticks, fontsize=8)
+ plt.xlabel(kwargs.get("xlabel", ""))
+ plt.ylabel(kwargs.get("ylabel", ""))
+
+ # Percentiles
+ percentiles = kwargs.get("percentiles", None)
+ if percentiles is not None:
+ for percentile in percentiles:
+ perc = np.percentile(data, percentile)
+ plt.axvline(x=perc, color="red")
+ plt.text(
+ perc,
+ max(hist) * 1.05,
+ str(percentile),
+ color="red",
+ ha="center",
+ va="bottom",
+ )
+
+ # Save and close
+ plt.savefig(f"{alias}.pdf")
+ plt.close()
+
+
+# Too specific, make more general
+def make_plot(counter, alias, **kwargs):
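+    """
+    Plot the counts from ``counter`` sorted in descending order (rank-frequency plot)
+    and save it as ``<alias>.pdf``. Vertical lines mark the ranks at which the
+    cumulative count reaches the given ``percentiles``.
+    """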
+ data = sorted(counter.items(), key=lambda x: x[1], reverse=True)
+
+ total_cnt = 0
+ indices, counts, cum_counts = [], [], []
+
+ for idx, (_, cnt) in enumerate(data, start=1):
+ indices.append(idx)
+ counts.append(cnt)
+
+ total_cnt += cnt
+ cum_counts.append(total_cnt)
+
+ # plt.figure(figsize=(18, 12)) # fixed
+ plot = plt.plot(
+ indices,
+ counts,
+ alpha=0.75,
+ marker="o",
+ markersize=3,
+ markerfacecolor="none",
+ linestyle="-",
+ color="royalblue",
+ )
+
+ # Adjust for percentiles' indices
+ ax = plt.gca()
+ yticks_diff = ax.get_yticks()[1] - ax.get_yticks()[0]
+ plt.ylim(-0.5 * yticks_diff)
+ bottom_loc = -0.4 * yticks_diff
+
+ # Percentiles
+ cumulative_fraction = (np.array(cum_counts) / total_cnt) * 100
+ percentiles = kwargs.get("percentiles", None)
+ if percentiles is not None:
+ for percentile in percentiles:
+ percentile_idx = np.argmax(cumulative_fraction >= percentile)
+ plt.axvline(x=percentile_idx + 1, color="r")
+ plt.text(
+ percentile_idx + 1,
+ max(counts) * 1.05,
+ str(percentile),
+ color="r",
+ ha="center",
+ va="bottom",
+ )
+ plt.text(
+ percentile_idx + 5 * len(str(percentile_idx)), # 15-20
+ bottom_loc,
+ str(percentile_idx),
+ color="r",
+ ha="left",
+ fontsize=8,
+ )
+
+ plt.xlabel(kwargs.get("xlabel", ""))
+ plt.ylabel(kwargs.get("ylabel", ""))
+ plt.savefig(f"{alias}.pdf")
+ plt.close()
+
+
+def calc_bpe_statistics(bliss_lexicon, transcription, **kwargs):
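+    """
+    Compute BPE statistics from a Bliss lexicon (word -> BPE segmentation) and a corpus
+    transcription, and write the results (mean values, counters as JSON, histograms and
+    plots as PDF, plus a bpe_statistics.txt summary) to the current working directory.
+    """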
+    # gzip.open() does not fail for plain files until the first read,
+    # so decide based on the file extension instead of try/except
+    if str(bliss_lexicon).endswith(".gz"):
+        bliss_lexicon = gzip.open(bliss_lexicon, "rb")
+
+ tree = ET.parse(bliss_lexicon)
+ root = tree.getroot()
+
+ # If present, omit [blank], [SILENCE], [UNKNOWN] etc.
+ bpe_tokens = tree.findall(
+ './/phoneme-inventory/phoneme[variation="context"]/symbol'
+ )
+ bpe_vocab = [bpe_token.text for bpe_token in bpe_tokens]
+
+ # Average number of symbols that comprise a BPE token
+ # num_bpe_tokens = len(bpe_vocab), avoid redundant computations
+ total_num_symbols = 0
+ symbols_per_token = Counter()
+ num_symbols_per_token = Counter()
+
+ for bpe_token in bpe_vocab:
+ curr_num_symbols = len(bpe_token.replace("@@", ""))
+ symbols_per_token.update({bpe_token: curr_num_symbols})
+ num_symbols_per_token.update([curr_num_symbols])
+ total_num_symbols += curr_num_symbols
+
+ # Mean (token-level statistics)
+ mean_num_symbols_per_token = total_num_symbols / len(bpe_vocab)
+ filename = "mean_num_symbols_per_token.txt"
+ with open(filename, "w+") as f:
+ f.write(str(float(mean_num_symbols_per_token)))
+
+ # # Histogram (token-level statistics)
+ # data = [*num_symbols_per_token.elements()]
+ # num_bins = len(num_symbols_per_token.keys()) + 1 # capture full range of values
+ #
+ # # General
+ # plt.hist(data, bins=np.arange(1, num_bins+1) - 0.5, alpha=0.75, range=(1, num_bins+1), ec='black')
+ # plt.xticks(np.arange(1,num_bins), fontsize=8)
+ # plt.xlabel("Number of symbols per BPE token")
+ # plt.ylabel("Number of occurrences")
+ #
+ # # 95-th quantile w/ disambiguation
+ # percentile = np.percentile(data, 95)
+ # plt.axvline(x=percentile, color='red')
+ # plt.text(percentile, max(hist) + 40, '95', color='red', ha='center', va='bottom')
+ #
+ # plt.savefig("num_symbols_per_token_histogram.pdf")
+ # plt.close()
+
+ # Histogram (token-level statistics)
+ kwargs.update(
+ {
+ "xlabel": "Number of symbols per BPE token",
+ "ylabel": "Number of occurrences",
+ "percentiles": [95],
+ }
+ )
+ make_histogram(num_symbols_per_token, "num_symbols_per_token_histogram", **kwargs)
+
+ # --------------------------------- #
+
+ # Word-level statistics - iterate over tokenization-dict
+ tokenization = {}
+ # Take care to omit special lemmata
+ for lemma in root.findall(".//lemma"):
+ if not lemma.attrib:
+ orth = lemma.find(".//orth")
+ phon = lemma.find(".//phon")
+
+ # Only works for unique sequences
+ if orth is not None and phon is not None:
+ tokenization[orth.text] = phon.text.split()
+
+ total_num_tokens = 0 # num_words = len(tokenization), avoid redundant computations
+ token_count_per_vocab = Counter()
+ num_tokens_per_word = Counter()
+
+ for word, subwords in tokenization.items():
+ curr_num_tokens = len(subwords)
+ num_tokens_per_word.update([curr_num_tokens])
+ total_num_tokens += curr_num_tokens
+ for subword in subwords:
+ token_count_per_vocab.update([subword])
+
+ # Means (word-level/vocab-level statistics)
+ mean_num_token_per_word = total_num_tokens / len(tokenization) # number of words
+ filename = "mean_num_token_per_word.txt"
+ with open(filename, "w+") as f:
+ f.write(str(float(mean_num_token_per_word)))
+
+ mean_token_count_per_vocab = total_num_tokens / len(
+ token_count_per_vocab
+ ) # number of BPE subwords
+ filename = "mean_token_count_per_vocab.txt"
+ with open(filename, "w+") as f:
+ f.write(str(float(mean_token_count_per_vocab)))
+
+ # Dump BPE token frequency (vocabulary)
+ filename = "token_count_per_vocab.json"
+ with open(filename, "w+") as f:
+ json.dump(token_count_per_vocab, f, indent=0)
+
+ # Visualisations (word-level statistics)
+ kwargs.update(
+ {
+ "xlabel": "Number of BPE tokens per word",
+ # 'ylabel': "Number of occurrences",
+ "percentiles": [95],
+ }
+ )
+ make_histogram(num_tokens_per_word, "num_token_per_word_histogram", **kwargs)
+
+ kwargs.update(
+ {
+ "xlabel": "BPE tokens' counts (vocabulary)",
+ # 'ylabel': "Number of occurrences",
+ "percentiles": [90, 95, 99],
+ }
+ )
+ make_plot(token_count_per_vocab, "token_count_per_vocab_plot", **kwargs)
+
+ # --------------------------------- #
+
+ # Sequence-level statistics - iterate over transcription (list of text sequences)
+ total_num_tokens = 0
+ total_num_sequences = 0
+ oov_words = Counter()
+ token_count_per_corpus = Counter()
+ num_tokens_per_sequence = Counter()
+
+ with open(transcription, "rt") as f:
+ for sequence in f:
+ curr_num_tokens = 0
+
+ for word in sequence.split():
+ subwords = tokenization.get(word, [])
+ if subwords:
+ curr_num_tokens += len(subwords)
+
+ for subword in subwords:
+ token_count_per_corpus.update([subword])
+ else:
+ oov_words.update([word])
+
+ num_tokens_per_sequence.update(
+ [curr_num_tokens]
+ ) # num_tokens per given sequence
+ total_num_tokens += curr_num_tokens
+ total_num_sequences += 1
+
+ # Means (transcription-level statistics)
+ mean_num_token_per_sequence = total_num_tokens / total_num_sequences
+ filename = "mean_num_token_per_sequence.txt"
+ with open(filename, "w+") as f:
+ f.write(str(float(mean_num_token_per_sequence)))
+
+ mean_token_count_per_corpus = total_num_tokens / len(
+ token_count_per_corpus
+ ) # number of BPE subwords
+ filename = "mean_token_count_per_corpus.txt"
+ with open(filename, "w+") as f:
+ f.write(str(float(mean_token_count_per_corpus)))
+
+ # Dump BPE token frequency (corpus)
+ filename = "token_count_per_corpus.json"
+ with open(filename, "w+") as f:
+ json.dump(token_count_per_corpus, f, indent=0)
+
+ # Dump OOV words (words without BPE-tokenization in lexicon)
+ filename = "oov_words.json"
+ with open(filename, "w+") as f:
+ json.dump(oov_words, f, indent=0)
+
+ # Visualisations (transcription-level statistics)
+ kwargs.update(
+ {
+ "xlabel": "Number of BPE tokens per sequence",
+ # 'ylabel': "Number of occurrences",
+ "percentiles": [95],
+ }
+ )
+ make_histogram(
+ num_tokens_per_sequence, "num_token_per_sequence_histogram", **kwargs
+ )
+
+ kwargs.update(
+ {
+ "xlabel": "BPE tokens' counts (corpus)",
+ # 'ylabel': "Number of occurrences",
+ "percentiles": [90, 95, 99],
+ }
+ )
+ make_plot(token_count_per_corpus, "token_count_per_corpus_plot", **kwargs)
+
+ # --------------------------------- #
+
+ filename = "bpe_statistics.txt"
+ with open(filename, "w+") as f:
+ f.write("BPE STATISTICS:\n")
+ f.write(
+ f"\t Mean number of symbols per BPE token: {mean_num_symbols_per_token}.\n\n"
+ )
+ f.write(f"\t Mean number of BPE tokens per word: {mean_num_token_per_word}.\n")
+ f.write(
+ f"\t Mean number of BPE tokens per sequence: {mean_num_token_per_sequence}.\n\n"
+ )
+ f.write(
+ f"\t Mean count of BPE tokens in vocabulary: {mean_token_count_per_vocab}.\n"
+ )
+ f.write(
+ f"\t Mean count of BPE tokens in corpus: {mean_token_count_per_corpus}.\n"
+ )
+
+
+def main():
+ arg_parser = argparse.ArgumentParser(
+ description="Calculate BPE subword statistics."
+ )
+ arg_parser.add_argument(
+ "bliss_lexicon", help="Bliss lexicon with word-to-tokenization correspondence."
+ )
+ arg_parser.add_argument(
+ "transcription", help="Corpus text corresponding to a Bliss corpus."
+ )
+ args = arg_parser.parse_args()
+
+ # TODO
+    hist_kwargs = {"xlabel": None, "ylabel": None, "percentiles": [95]}
+ calc_bpe_statistics(args.bliss_lexicon, args.transcription, **hist_kwargs)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/users/gruev/tools/schmitt_segment_statistics.py b/users/gruev/tools/schmitt_segment_statistics.py
new file mode 100644
index 000000000..6ff0769e1
--- /dev/null
+++ b/users/gruev/tools/schmitt_segment_statistics.py
@@ -0,0 +1,264 @@
+import argparse
+import json
+import sys
+import numpy as np
+from collections import Counter
+import matplotlib.pyplot as plt
+
+def calc_segment_stats_with_sil(blank_idx, sil_idx):
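+    """
+    Iterate over the global HDF ``dataset`` and compute segment statistics for an
+    alignment with a dedicated silence label: mean silence/non-silence segment lengths,
+    label-dependent mean lengths, length histograms and 90/95/99 percentiles, all
+    written to files in the current working directory.
+    """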
+
+ dataset.init_seq_order()
+ seq_idx = 0
+ inter_sil_seg_len = 0
+ init_sil_seg_len = 0
+ final_sil_seg_len = 0
+ label_seg_len = 0
+ num_blank_frames = 0
+ num_label_segs = 0
+ num_sil_segs = 0
+ num_init_sil_segs = 0
+ num_final_sil_segs = 0
+ num_seqs = 0
+ max_seg_len = 0
+
+ label_dependent_seg_lens = Counter()
+ label_dependent_num_segs = Counter()
+
+ map_non_sil_seg_len_to_count = Counter()
+ map_sil_seg_len_to_count = Counter()
+
+ while dataset.is_less_than_num_seqs(seq_idx):
+ num_seqs += 1
+ # progress indication
+ if seq_idx % 1000 == 0:
+ complete_frac = dataset.get_complete_frac(seq_idx)
+ print("Progress: %.02f" % (complete_frac * 100))
+ dataset.load_seqs(seq_idx, seq_idx + 1)
+ data = dataset.get_data(seq_idx, "data")
+
+ # if data[-1] == blank_idx:
+ # print("LAST IDX IS BLANK!")
+ # print(data)
+ # print(dataset.get_tag(seq_idx))
+ # print("------------------------")
+
+ # find non-blanks and silence
+ non_blank_idxs = np.where(data != blank_idx)[0]
+ sil_idxs = np.where(data == sil_idx)[0]
+
+ non_blank_data = data[data != blank_idx]
+
+ # count number of segments and number of blank frames
+ num_label_segs += len(non_blank_idxs) - len(sil_idxs)
+ num_sil_segs += len(sil_idxs)
+ num_blank_frames += len(data) - len(non_blank_idxs)
+
+ # if there are only blanks, we skip the seq as there are no segments
+ if non_blank_idxs.size == 0:
+ seq_idx += 1
+ continue
+ else:
+ prev_idx = 0
+ try:
+ # go through non blanks and count segment len
+ # differ between sil_beginning, sil_mid, sil_end and non-sil
+ for i, idx in enumerate(non_blank_idxs):
+ seg_len = idx - prev_idx
+ # first segment is always one too short because of prev_idx = 0
+ if prev_idx == 0:
+ seg_len += 1
+
+ if seg_len > max_seg_len:
+ max_seg_len = seg_len
+
+ # if seg_len > 20:
+ # print("SEQ WITH SEG LEN OVER 20:\n")
+ # print(data)
+ # print(dataset.get_tag(seq_idx))
+ # print("-------------------------------")
+ #
+ # if seg_len > 30:
+ # print("SEQ WITH SEG LEN OVER 30:\n")
+ # print(data)
+ # print(dataset.get_tag(seq_idx))
+ # print("-------------------------------")
+ #
+ # if seg_len > 40:
+ # print("SEQ WITH SEG LEN OVER 40:\n")
+ # print(data)
+ # print(dataset.get_tag(seq_idx))
+ # print("-------------------------------")
+ #
+ # if seg_len > 60:
+ # print("SEQ WITH SEG LEN OVER 60:\n")
+ # print(data)
+ # print(dataset.get_tag(seq_idx))
+ # print("-------------------------------")
+ #
+ # if seg_len > 80:
+ # print("SEQ WITH SEG LEN OVER 80:\n")
+ # print(data)
+ # print(dataset.get_tag(seq_idx))
+ # print("-------------------------------")
+
+ label_dependent_seg_lens.update({non_blank_data[i]: seg_len})
+ label_dependent_num_segs.update([non_blank_data[i]])
+
+ if idx in sil_idxs:
+ if i == 0:
+ init_sil_seg_len += seg_len
+ num_init_sil_segs += 1
+ map_sil_seg_len_to_count.update([seg_len])
+ elif i == len(non_blank_idxs) - 1:
+ final_sil_seg_len += seg_len
+ num_final_sil_segs += 1
+ map_sil_seg_len_to_count.update([seg_len])
+ else:
+ inter_sil_seg_len += seg_len
+ map_sil_seg_len_to_count.update([seg_len])
+ else:
+ label_seg_len += seg_len
+ map_non_sil_seg_len_to_count.update([seg_len])
+
+ prev_idx = idx
+ except IndexError:
+ continue
+
+ seq_idx += 1
+
+ mean_init_sil_len = init_sil_seg_len / num_init_sil_segs if num_init_sil_segs > 0 else 0
+ mean_final_sil_len = final_sil_seg_len / num_final_sil_segs if num_final_sil_segs > 0 else 0
+ mean_inter_sil_len = inter_sil_seg_len / (num_sil_segs - num_init_sil_segs - num_final_sil_segs) if inter_sil_seg_len > 0 else 0
+ mean_total_sil_len = (init_sil_seg_len + final_sil_seg_len + inter_sil_seg_len) / num_seqs
+
+ mean_label_len = label_seg_len / num_label_segs
+ mean_total_label_len = label_seg_len / num_seqs
+
+ label_dependent_mean_seg_lens = {int(idx): label_dependent_seg_lens[idx] / label_dependent_num_segs[idx] for idx in label_dependent_seg_lens }
+ label_dependent_mean_seg_lens.update({idx: mean_label_len for idx in range(blank_idx) if idx not in label_dependent_mean_seg_lens})
+
+ mean_seq_len = (num_blank_frames + num_sil_segs + num_label_segs) / num_seqs
+
+ num_segments_shorter2 = sum([count for seg_len, count in map_non_sil_seg_len_to_count.items() if seg_len < 2])
+ num_segments_shorter4 = sum([count for seg_len, count in map_non_sil_seg_len_to_count.items() if seg_len < 4])
+ num_segments_shorter8 = sum([count for seg_len, count in map_non_sil_seg_len_to_count.items() if seg_len < 8])
+ num_segments_shorter16 = sum([count for seg_len, count in map_non_sil_seg_len_to_count.items() if seg_len < 16])
+ num_segments_shorter32 = sum([count for seg_len, count in map_non_sil_seg_len_to_count.items() if seg_len < 32])
+ num_segments_shorter64 = sum([count for seg_len, count in map_non_sil_seg_len_to_count.items() if seg_len < 64])
+ num_segments_shorter128 = sum([count for seg_len, count in map_non_sil_seg_len_to_count.items() if seg_len < 128])
+
+ filename = "statistics"
+ with open(filename, "w+") as f:
+ f.write("Segment statistics: \n\n")
+ f.write("\tSilence: \n")
+ f.write("\t\tInitial:\n")
+ f.write("\t\t\tMean length: %f \n" % mean_init_sil_len)
+ f.write("\t\t\tNum segments: %f \n" % num_init_sil_segs)
+ f.write("\t\tIntermediate:\n")
+ f.write("\t\t\tMean length: %f \n" % mean_inter_sil_len)
+ f.write("\t\t\tNum segments: %f \n" % (num_sil_segs - num_init_sil_segs - num_final_sil_segs))
+ f.write("\t\tFinal:\n")
+ f.write("\t\t\tMean length: %f \n" % mean_final_sil_len)
+ f.write("\t\t\tNum segments: %f \n" % num_final_sil_segs)
+ f.write("\t\tTotal per sequence:\n")
+ f.write("\t\t\tMean length: %f \n" % mean_total_sil_len)
+ f.write("\t\t\tNum segments: %f \n" % num_sil_segs)
+ f.write("\n")
+ f.write("\tNon-silence: \n")
+ f.write("\t\tMean length per segment: %f \n" % mean_label_len)
+ f.write("\t\tMean length per sequence: %f \n" % mean_total_label_len)
+ f.write("\t\tNum segments: %f \n" % num_label_segs)
+ f.write("\t\tPercent segments shorter than x frames: \n")
+ f.write("\t\tx = 2: %f \n" % (num_segments_shorter2 / num_label_segs))
+ f.write("\t\tx = 4: %f \n" % (num_segments_shorter4 / num_label_segs))
+ f.write("\t\tx = 8: %f \n" % (num_segments_shorter8 / num_label_segs))
+ f.write("\t\tx = 16: %f \n" % (num_segments_shorter16 / num_label_segs))
+ f.write("\t\tx = 32: %f \n" % (num_segments_shorter32 / num_label_segs))
+ f.write("\t\tx = 64: %f \n" % (num_segments_shorter64 / num_label_segs))
+ f.write("\t\tx = 128: %f \n" % (num_segments_shorter128 / num_label_segs))
+ f.write("\n")
+ f.write("Overall maximum segment length: %d \n" % max_seg_len)
+ f.write("\n")
+ f.write("\n")
+ f.write("Sequence statistics: \n\n")
+ f.write("\tMean length: %f \n" % mean_seq_len)
+ f.write("\tNum sequences: %f \n" % num_seqs)
+
+ filename = "mean_non_sil_len"
+ with open(filename, "w+") as f:
+ f.write(str(float(mean_label_len)))
+
+ filename = "label_dep_mean_lens"
+ with open(filename, "w+") as f:
+ json.dump(label_dependent_mean_seg_lens, f)
+
+ # plot histograms non-sil segment lens
+ hist_data = [item for seg_len, count in map_non_sil_seg_len_to_count.items() for item in [seg_len] * count]
+ plt.hist(hist_data, bins=30, range=(0, 50))
+ ax = plt.gca()
+ quantiles = [np.quantile(hist_data, q) for q in [.90, .95, .99]]
+ for n, q in zip([90, 95, 99], quantiles):
+ # write quantiles to file
+ with open("percentile_%s" % n, "w+") as f:
+ f.write(str(q))
+ ax.axvline(q, color="r")
+ plt.savefig("non_sil_histogram.pdf")
+ plt.close()
+
+ # plot histograms sil segment lens
+ hist_data = [item for seg_len, count in map_sil_seg_len_to_count.items() for item in [seg_len] * count]
+ plt.hist(hist_data, bins=40, range=(0, 100))
+ if len(hist_data) != 0:
+ ax = plt.gca()
+ quantiles = [np.quantile(hist_data, q) for q in [.90, .95, .99]]
+ for q in quantiles:
+ ax.axvline(q, color="r")
+ plt.savefig("sil_histogram.pdf")
+ plt.close()
+
+
+def init(hdf_file, seq_list_filter_file):
+ rnn.init_better_exchook()
+ rnn.init_thread_join_hack()
+ dataset_dict = {
+ "class": "HDFDataset", "files": [hdf_file], "use_cache_manager": True, "seq_list_filter_file": seq_list_filter_file
+ }
+
+ rnn.init_config(config_filename=None, default_config={"cache_size": 0})
+ global config
+ config = rnn.config
+ config.set("log", None)
+ global dataset
+ dataset = rnn.init_dataset(dataset_dict)
+ rnn.init_log()
+ print("Returnn segment-statistics starting up", file=rnn.log.v2)
+ rnn.returnn_greeting()
+ rnn.init_faulthandler()
+ rnn.init_config_json_network()
+
+
+def main():
+ arg_parser = argparse.ArgumentParser(description="Calculate segment statistics.")
+ arg_parser.add_argument("hdf_file", help="hdf file which contains the extracted alignments of some corpus")
+ arg_parser.add_argument("--seq-list-filter-file", help="whitelist of sequences to use", default=None)
+ arg_parser.add_argument("--blank-idx", help="the blank index in the alignment", default=0, type=int)
+    arg_parser.add_argument("--sil-idx", help="the silence index in the alignment", default=None)
+ arg_parser.add_argument("--returnn-root", help="path to returnn root")
+ args = arg_parser.parse_args()
+ sys.path.insert(0, args.returnn_root)
+ global rnn
+ import returnn.__main__ as rnn
+
+ init(args.hdf_file, args.seq_list_filter_file)
+
+ try:
+ calc_segment_stats_with_sil(args.blank_idx, args.sil_idx)
+ except KeyboardInterrupt:
+ print("KeyboardInterrupt")
+ sys.exit(1)
+ finally:
+ rnn.finalize()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/users/gruev/tools/segment_statistics.py b/users/gruev/tools/segment_statistics.py
new file mode 100755
index 000000000..c88a8f760
--- /dev/null
+++ b/users/gruev/tools/segment_statistics.py
@@ -0,0 +1,270 @@
+import argparse
+import json
+import sys
+import numpy as np
+from collections import Counter
+import matplotlib.pyplot as plt
+
+# This works if blank_idx == 0
+def calc_ctc_segment_stats(num_labels, blank_idx=0):
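+    """
+    Iterate over the global HDF ``dataset`` and compute CTC segment statistics:
+    mean label segment length per segment/sequence/label, the segment length
+    distribution, 90/95/99 quantiles and a histogram, all written to files in the
+    current working directory.
+    """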
+ dataset.init_seq_order()
+
+ # HDF Dataset Utterance (=Sequence) Iteration
+ # In the end, total_num_seq := (curr_seq_idx + 1)
+ curr_seq_idx = 0
+
+ # blank and non-blank frame cnt (across utterances)
+ num_blank_frames, num_label_frames = 0, 0
+
+ # start-blank, end-blank, mid-blank (and non-blank)
+ # individual counts (per sequence) and number of such segments
+ # start_blank_cnt, mid_blank_cnt, end_blank_cnt = 0, 0, 0
+ # start_blank_seg_cnt, end_blank_seg_cnt = 0, 0
+
+ # Maximal segment duration
+ max_seg_len = 0
+
+ # Counter for segment duration indexed by label idx
+ label_seg_lens = Counter()
+ # Counter for number of label idx occurrences
+ label_seg_freq = Counter()
+ # Counter for number label segment duration occurrences
+ label_seg_len_freq = Counter()
+
+ # # Counter for blank segments duration
+ # blank_seg_len = Counter()
+ # # Counter for non-blank segments duration
+ # non_blank_seg_len = Counter()
+
+ while dataset.is_less_than_num_seqs(curr_seq_idx):
+ dataset.load_seqs(curr_seq_idx, curr_seq_idx + 1)
+ data = dataset.get_data(curr_seq_idx, "data")
+
+ # Filter [blank] only (no [SILENCE])
+ blanks_idx = np.where(data == blank_idx)[0]
+ labels_idx = np.where(data != blank_idx)[0]
+ labels = data[labels_idx]
+
+ # Counts for [blank] and labels (non-blanks)
+ num_blank_frames += len(blanks_idx)
+ num_label_frames += len(labels_idx)
+
+ # If there are only blanks, skip the current sequence
+ if len(labels_idx) == 0:
+ curr_seq_idx += 1
+ continue
+
+ # Frame duration between non-blank indices
+ curr_seq_seg_len = np.diff(labels_idx, prepend=-1)
+
+ # Update current max segment duration per sequence
+ curr_seq_max_seg_len = np.max(curr_seq_seg_len)
+ if curr_seq_max_seg_len > max_seg_len:
+ max_seg_len = curr_seq_max_seg_len
+
+ # Update number of label idx occurrences
+ label_seg_freq.update(labels)
+
+ # Update number of label segment duration occurrences
+ label_seg_len_freq.update(curr_seq_seg_len)
+
+ # Update durations of label segments
+ for label, seg_len in zip(labels, curr_seq_seg_len):
+ label_seg_lens.update({label: seg_len})
+
+ # # Compute initial, intermediate, final blank frame counts
+ # first_non_blank, last_non_blank = labels_idx[0], labels_idx[-1]
+ # if first_non_blank > 0:
+ # start_blank_cnt += labels_idx[0]
+ # start_blank_seg_cnt += 1
+ # if last_non_blank < len(data) - 1:
+ # end_blank_cnt += len(data) - labels_idx[-1] - 1
+ # end_blank_seg_cnt += 1
+ # mid_blank_cnt = np.count_nonzero(data == 0) - start_blank_cnt - end_blank_cnt
+
+ curr_seq_idx += 1
+
+ # total_seg_cnt = sum(label_seg_freq.values()) # == num_label_frames
+
+ # label_seg_lens holds the corresponding segment lengths
+ total_seg_len_1 = sum(label_seg_lens.values())
+ total_seg_len_2 = sum([k * v for (k, v) in label_seg_len_freq.items()])
+ assert total_seg_len_1 == total_seg_len_2
+
+ total_seg_len = total_seg_len_1
+
+ # Mean label length per segment
+ mean_label_seg_len = total_seg_len / num_label_frames
+ # Mean label length per sequence
+ mean_label_seq_len = total_seg_len / curr_seq_idx # =: num_seqs
+
+ # Mean duration of segment for each label
+ mean_label_seg_lens = {}
+
+ for idx in range(num_labels):
+ if label_seg_freq[idx] == 0:
+ mean_label_seg_lens[idx] = 0
+ else:
+ mean_label_seg_lens[idx] = label_seg_lens[idx] / label_seg_freq[idx]
+
+ # Mean sequence length
+ mean_seq_len = (num_blank_frames + num_label_frames) / curr_seq_idx
+
+ # Length statistics of label segments
+ num_seg_lt2 = sum(
+ [
+ count
+ for label_seg_len, count in label_seg_len_freq.items()
+ if label_seg_len < 2
+ ]
+ )
+ num_seg_lt4 = sum(
+ [
+ count
+ for label_seg_len, count in label_seg_len_freq.items()
+ if label_seg_len < 4
+ ]
+ )
+ num_seg_lt8 = sum(
+ [
+ count
+ for label_seg_len, count in label_seg_len_freq.items()
+ if label_seg_len < 8
+ ]
+ )
+ num_seg_lt16 = sum(
+ [
+ count
+ for label_seg_len, count in label_seg_len_freq.items()
+ if label_seg_len < 16
+ ]
+ )
+ num_seg_lt32 = sum(
+ [
+ count
+ for label_seg_len, count in label_seg_len_freq.items()
+ if label_seg_len < 32
+ ]
+ )
+ num_seg_lt64 = sum(
+ [
+ count
+ for label_seg_len, count in label_seg_len_freq.items()
+ if label_seg_len < 64
+ ]
+ )
+ num_seg_lt128 = sum(
+ [
+ count
+ for label_seg_len, count in label_seg_len_freq.items()
+ if label_seg_len < 128
+ ]
+ )
+
+ # Overview of computed statistics
+ filename = "statistics.txt"
+ with open(filename, "w+") as f:
+ f.write("\tNon-silence: \n")
+ f.write("\t\tMean length per segment: %f \n" % mean_label_seg_len)
+ f.write("\t\tMean length per sequence: %f \n" % mean_label_seq_len)
+ f.write("\t\tNum segments: %f \n" % num_label_frames)
+ f.write("\t\tPercent segments shorter than x frames: \n")
+ f.write("\t\tx = 2: %f \n" % (num_seg_lt2 / num_label_frames))
+ f.write("\t\tx = 4: %f \n" % (num_seg_lt4 / num_label_frames))
+ f.write("\t\tx = 8: %f \n" % (num_seg_lt8 / num_label_frames))
+ f.write("\t\tx = 16: %f \n" % (num_seg_lt16 / num_label_frames))
+ f.write("\t\tx = 32: %f \n" % (num_seg_lt32 / num_label_frames))
+ f.write("\t\tx = 64: %f \n" % (num_seg_lt64 / num_label_frames))
+ f.write("\t\tx = 128: %f \n" % (num_seg_lt128 / num_label_frames))
+
+ f.write("\n")
+ f.write("Overall maximum segment length: %d \n" % max_seg_len)
+ f.write("\n\n")
+
+ f.write("Sequence statistics: \n\n")
+ f.write("\tMean length: %f \n" % mean_seq_len)
+ f.write("\tNum sequences: %f \n" % curr_seq_idx)
+
+ filename = "mean_label_seq_len.txt"
+ with open(filename, "w+") as f:
+ f.write(str(float(mean_label_seq_len)))
+
+ filename = "mean_label_seg_len.txt"
+ with open(filename, "w+") as f:
+ f.write(str(float(mean_label_seg_len)))
+
+ filename = "mean_label_seg_lens.json"
+ with open(filename, "w+") as f:
+ json.dump(mean_label_seg_lens, f)
+
+ # Histogram for label segment lengths
+ hist_data = [
+ item
+ for label_seg_len, count in label_seg_len_freq.items()
+ for item in [label_seg_len] * count
+ ]
+ plt.hist(hist_data, bins=40, range=(0, 40))
+ ax = plt.gca()
+ quantiles = [np.quantile(hist_data, q) for q in [0.90, 0.95, 0.99]]
+ for n, q in zip([90, 95, 99], quantiles):
+ # Write quantiles to files
+ with open("quantile_%s" % n, "w+") as f:
+ f.write(str(q))
+ ax.axvline(q, color="r")
+ plt.savefig("labels_histogram.pdf")
+ plt.close()
+
+
+def init(hdf_file, seq_list_filter_file):
+ rnn.init_better_exchook()
+ rnn.init_thread_join_hack()
+ dataset_dict = {
+ "class": "HDFDataset",
+ "files": [hdf_file],
+ "use_cache_manager": True,
+ "seq_list_filter_file": seq_list_filter_file,
+ }
+
+ rnn.init_config(config_filename=None, default_config={"cache_size": 0})
+ global config
+ config = rnn.config
+ config.set("log", None)
+ global dataset
+ dataset = rnn.init_dataset(dataset_dict)
+ rnn.init_log()
+ print("Returnn segment-statistics starting up...", file=rnn.log.v2)
+ rnn.returnn_greeting()
+ rnn.init_faulthandler()
+ rnn.init_config_json_network()
+
+
+def main():
+ arg_parser = argparse.ArgumentParser(description="Calculate alignment statistics.")
+ arg_parser.add_argument(
+ "hdf_file",
+ help="hdf file which contains the extracted alignments of some corpus",
+ )
+ arg_parser.add_argument(
+ "--seq-list-filter-file", help="whitelist of sequences to use", default=None
+ )
+ arg_parser.add_argument(
+ "--blank-idx", help="the blank index in the alignment", default=0, type=int
+ )
+ arg_parser.add_argument(
+ "--num-labels", help="the total number of labels in the alignment", type=int
+ )
+ arg_parser.add_argument("--returnn-root", help="path to RETURNN root")
+
+ args = arg_parser.parse_args()
+ sys.path.insert(0, args.returnn_root)
+
+ global rnn
+ import returnn.__main__ as rnn
+
+ init(args.hdf_file, args.seq_list_filter_file)
+ calc_ctc_segment_stats(args.num_labels, args.blank_idx)
+ rnn.finalize()
+
+
+if __name__ == "__main__":
+ main()