diff --git a/users/gruev/lexicon/bpe_lexicon.py b/users/gruev/lexicon/bpe_lexicon.py new file mode 100644 index 000000000..14d7fdd77 --- /dev/null +++ b/users/gruev/lexicon/bpe_lexicon.py @@ -0,0 +1,333 @@ +__all__ = ["CreateBPELexiconJob", "ApplyBPEToTextJob"] + +import os +import sys +import shutil +import tempfile +import subprocess as sp + +import xml.etree.ElementTree as ET +from typing import Optional + +import i6_core.util as util +from i6_core.lib.lexicon import Lexicon, Lemma +from sisyphus import tk, gs, Task, Job + + +class CreateBPELexiconJob(Job): + """ + Create a Bliss lexicon from bpe transcriptions. + """ + + def __init__( + self, + base_lexicon_path, + bpe_codes, + bpe_vocab, + subword_nmt_repo=None, + unk_label="UNK", + ): + """ + :param Path base_lexicon_path: + :param Path bpe_codes: + :param Path|None bpe_vocab: + :param Path|str|None subword_nmt_repo: + """ + self.base_lexicon_path = base_lexicon_path + self.bpe_codes = bpe_codes + self.bpe_vocab = bpe_vocab + self.subword_nmt_repo = ( + subword_nmt_repo if subword_nmt_repo is not None else gs.SUBWORD_NMT_PATH + ) + self.unk_label = unk_label + + self.out_lexicon = self.output_path("lexicon.xml.gz", cached=True) + + def tasks(self): + yield Task("run", resume="run", mini_task=True) + + def run(self): + lexicon = Lexicon() + + lm_tokens = set() + + base_lexicon = Lexicon() + base_lexicon.load(self.base_lexicon_path) + for l in base_lexicon.lemmata: + if l.special is None: + for orth in l.orth: + lm_tokens.add(orth) + for token in l.synt or []: # l.synt can be None + lm_tokens.add(token) + for eval in l.eval: + for t in eval: + lm_tokens.add(t) + + lm_tokens = list(lm_tokens) + + with util.uopen("words", "wt") as f: + for t in lm_tokens: + f.write(f"{t}\n") + + vocab = set() + with util.uopen(self.bpe_vocab.get_path(), "rt") as f, util.uopen("fake_count_vocab.txt", "wt") as vocab_file: + for line in f: + if "{" in line or "" in line or "" in line or "}" in line: + continue + symbol = line.split(":")[0][1:-1] + if symbol != self.unk_label: + vocab_file.write(symbol + " -1\n") + symbol = symbol.replace(".", "_") + vocab.add(symbol) + lexicon.add_phoneme(symbol.replace(".", "_")) + + apply_binary = os.path.join( + tk.uncached_path(self.subword_nmt_repo), "apply_bpe.py" + ) + args = [ + sys.executable, + apply_binary, + "--input", + "words", + "--codes", + self.bpe_codes.get_path(), + "--vocabulary", + "fake_count_vocab.txt", + "--output", + "bpes", + ] + sp.run(args, check=True) + + with util.uopen("bpes", "rt") as f: + bpe_tokens = [l.strip() for l in f] + + w2b = {w: b for w, b in zip(lm_tokens, bpe_tokens)} + + for w, b in w2b.items(): + b = " ".join( + [token if token in vocab else self.unk_label for token in b.split()] + ) + lexicon.add_lemma(Lemma([w], [b.replace(".", "_")])) + + elem = lexicon.to_xml() + tree = ET.ElementTree(elem) + util.write_xml(self.out_lexicon.get_path(), tree) + + +# class CreateBPELexiconJob(Job): +# """ +# Create a Bliss lexicon from BPE transcriptions. +# """ +# +# def __init__( +# self, +# base_lexicon_path: tk.Path, +# bpe_codes: tk.Path, +# bpe_vocab: tk.Path, +# subword_nmt_repo: Optional[tk.Path] = None, +# unk_label: str = "[UNKNOWN]", +# add_silence: bool = True, +# add_other_special: bool = False, +# ): +# """ +# :param tk.Path base_lexicon_path: path to a Bliss lexicon +# :param tk.Path bpe_codes: path to BPE codes produced by i.e. ReturnnTrainBPEJob +# :param tk.Path bpe_vocab: path to BPE vocab produced by i.e. 
ReturnnTrainBPEJob +# :param tk.Path|None subword_nmt_repo: +# :param str unk_label: +# :param bool add_silence: explicitly include a [SILENCE] phoneme and lemma +# :param bool add_other_special: explicitly include special lemmata from base_lexicon_path +# """ +# +# self.base_lexicon_path = base_lexicon_path +# self.bpe_codes = bpe_codes +# self.bpe_vocab = bpe_vocab +# self.subword_nmt_repo = ( +# subword_nmt_repo if subword_nmt_repo is not None else gs.SUBWORD_NMT_PATH +# ) +# self.unk_label = unk_label +# self.add_silence = add_silence +# self.add_other_special = add_other_special +# +# self.out_lexicon = self.output_path("lexicon.xml.gz", cached=True) +# +# def tasks(self): +# yield Task("run", resume="run", mini_task=True) +# +# def run(self): +# lm_tokens = set() +# other_special = [] +# +# base_lexicon = Lexicon() +# base_lexicon.load(self.base_lexicon_path) +# +# for l in base_lexicon.lemmata: +# if l.special: +# if l.special not in ["silence", "unknown"]: +# other_special.append(l) +# continue +# for orth in l.orth or []: # l.orth can be None +# lm_tokens.add(orth) +# for token in l.synt or []: # l.synt can be None +# lm_tokens.add(token) +# for eval in l.eval or []: # l.eval can be None +# for t in eval: +# lm_tokens.add(t) +# +# lm_tokens = [lt for lt in lm_tokens if lt != ''] # catch , or +# +# with util.uopen("words", "wt") as f: +# for t in lm_tokens: +# f.write(f"{t}\n") +# +# vocab = set() +# lexicon = Lexicon() +# +# lexicon.add_phoneme(self.unk_label, variation="none") +# +# if self.add_silence: +# lexicon.add_phoneme("[SILENCE]", variation="none") +# +# with util.uopen(self.bpe_vocab.get_path(), "rt") as bpe_vocab_file: +# with util.uopen("fake_count_vocab", "wt") as fake_count_file: +# for line in bpe_vocab_file: +# if "{" in line or "" in line or "" in line or "}" in line: +# continue +# symbol = line.split(":")[0][1:-1] +# if symbol != self.unk_label: +# fake_count_file.write(symbol + " -1\n") +# symbol = symbol.replace(".", "_") +# vocab.add(symbol) +# lexicon.add_phoneme(symbol) +# +# apply_binary = os.path.join( +# tk.uncached_path(self.subword_nmt_repo), "apply_bpe.py" +# ) +# args = [ +# sys.executable, +# apply_binary, +# "--input", +# "words", +# "--codes", +# self.bpe_codes.get_path(), +# "--vocabulary", +# "fake_count_vocab", +# "--output", +# "bpes", +# ] +# sp.run(args, check=True) +# +# with util.uopen("bpes", "rt") as f: +# bpe_tokens = [l.strip() for l in f] +# +# w2b = {w: b for w, b in zip(lm_tokens, bpe_tokens)} +# +# lexicon.add_lemma( +# Lemma(["[UNKNOWN]"], [self.unk_label], None, None, special="unknown") +# ) +# +# if self.add_silence: +# lexicon.add_lemma( +# Lemma(["[SILENCE]"], ["[SILENCE]"], [], [[]], special="silence") +# ) +# +# if self.add_other_special: +# for l in other_special: +# lexicon.add_lemma(l) +# +# for w, b in w2b.items(): +# b = " ".join( +# [token if token in vocab else self.unk_label for token in b.split()] +# ) +# lexicon.add_lemma(Lemma([w], [b.replace(".", "_")])) +# +# elem = lexicon.to_xml() +# tree = ET.ElementTree(elem) +# util.write_xml(self.out_lexicon.get_path(), tree) + + +class ApplyBPEToTextJob(Job): + """ + Apply BPE codes on a text file + """ + + __sis_hash_exclude__ = {"gzip_output": False} + + def __init__( + self, + words_file: tk.Path, + bpe_codes: tk.Path, + bpe_vocab: tk.Path, + subword_nmt_repo: Optional[tk.Path] = None, + gzip_output: bool = False, + mini_task: bool = True, + ): + """ + :param tk.Path text_file: path to a words text file + :param tk.Path bpe_codes: path to BPE codes file, use e.g. 
ReturnnTrainBpeJob.out_bpe_codes + :param tk.Path bpe_vocab: path to BPE vocab file used to revert merge operations that produce OOV, + use e.g. ReturnnTrainBPEJob.out_bpe_vocab; + :param tk.Path/None subword_nmt_repo: path to subword nmt repository , see also `CloneGitRepositoryJob` + :param bool gzip_output: use gzip on the output text + :param bool mini_task: if the Job should run locally, e.g. only a small (<1M lines) text should be processed + """ + self.words_file = words_file + self.bpe_codes = bpe_codes + self.bpe_vocab = bpe_vocab + self.subword_nmt_repo = ( + subword_nmt_repo if subword_nmt_repo is not None else gs.SUBWORD_NMT_PATH + ) + self.gzip_output = gzip_output + + self.out_bpe_text = self.output_path("words_to_bpe.txt.gz" if gzip_output else "words_to_bpe.txt") + + self.mini_task = mini_task + self.rqmt = {"cpu": 1, "mem": 2, "time": 2} + + def tasks(self): + if self.mini_task: + yield Task("run", mini_task=True) + else: + yield Task("run", rqmt=self.rqmt) + + def run(self): + with tempfile.TemporaryDirectory(prefix=gs.TMP_PREFIX) as tmp: + words_file = self.words_file.get_path() + tmp_outfile = os.path.join(tmp, "out_text.txt") + + with util.uopen(self.bpe_vocab.get_path(), "rt") as bpe_vocab_file: + with util.uopen("fake_count_vocab.txt", "wt") as fake_count_file: + for line in bpe_vocab_file: + if "{" in line or "<" in line or "[" in line or "]" in line or ">" in line or "}" in line: + continue + symbol = line.split(":")[0][1:-1] + fake_count_file.write(symbol + " -1\n") + + apply_binary = os.path.join( + tk.uncached_path(self.subword_nmt_repo), "apply_bpe.py" + ) + cmd = [ + sys.executable, + apply_binary, + "--input", + words_file, + "--codes", + self.bpe_codes.get_path(), + "--vocabulary", + "fake_count_vocab.txt", + "--output", + tmp_outfile, + ] + util.create_executable("apply_bpe.sh", cmd) + sp.run(cmd, check=True) + + if self.gzip_output: + with util.uopen(tmp_outfile, "rt") as fin, util.uopen(self.out_bpe_text, "wb") as fout: + sp.call(["gzip"], stdin=fin, stdout=fout) + else: + shutil.copy(tmp_outfile, self.out_bpe_text.get_path()) + + @classmethod + def hash(cls, parsed_args): + del parsed_args["mini_task"] + return super().hash(parsed_args) diff --git a/users/gruev/statistics/alignment.py b/users/gruev/statistics/alignment.py new file mode 100644 index 000000000..738783658 --- /dev/null +++ b/users/gruev/statistics/alignment.py @@ -0,0 +1,102 @@ +import subprocess +import tempfile +import shutil +import os +import json + +from recipe.i6_core.util import create_executable + +from sisyphus import * + +import recipe.i6_private.users.gruev.tools as tools_mod +tools_dir = os.path.dirname(tools_mod.__file__) + + +class AlignmentStatisticsJob(Job): + def __init__( + self, + alignment, + num_labels, + blank_idx=0, + seq_list_filter_file=None, + time_rqmt=2, + returnn_python_exe=None, + returnn_root=None, + ): + + self.returnn_python_exe = ( + returnn_python_exe + if returnn_python_exe is not None + else gs.RETURNN_PYTHON_EXE + ) + self.returnn_root = ( + returnn_root if returnn_root is not None else gs.RETURNN_ROOT + ) + + self.alignment = alignment + self.seq_list_filter_file = seq_list_filter_file + self.num_labels = num_labels + self.blank_idx = blank_idx + + self.out_statistics = self.output_path("statistics.txt") + self.out_labels_hist = self.output_path("labels_histogram.pdf") + self.out_mean_label_seg_lens = self.output_path("mean_label_seg_lens.json") + # self.out_mean_label_seg_lens_var = self.output_var("mean_label_seg_lens_var.json") + 
self.out_mean_label_seg_len = self.output_path("mean_label_seg_len.txt") + # self.out_mean_label_seg_len_var = self.output_var("mean_label_seg_len_var.txt") + self.out_mean_label_seq_len = self.output_path("mean_label_seq_len.txt") + # self.out_mean_label_seq_len_var = self.output_var("mean_label_seq_len_var.txt") + self.out_90_quantile_var = self.output_var("quantile_90") + self.out_95_quantile_var = self.output_var("quantile_95") + self.out_99_quantile_var = self.output_var("quantile_99") + + self.time_rqmt = time_rqmt + + def tasks(self): + yield Task("run", rqmt={"cpu": 1, "mem": 4, "time": self.time_rqmt}) + + def run(self): + command = [ + self.returnn_python_exe.get_path(), + os.path.join(tools_dir, "segment_statistics.py"), + self.alignment.get_path(), + "--num-labels", + str(self.num_labels), + "--blank-idx", + str(self.blank_idx), + "--returnn-root", + self.returnn_root.get_path(), + ] + + if self.seq_list_filter_file: + command += ["--seq-list-filter-file", str(self.seq_list_filter_file)] + + create_executable("rnn.sh", command) + subprocess.check_call(["./rnn.sh"]) + + # with open("mean_label_seg_lens.json", "r") as f: + # mean_label_seg_lens = json.load(f) + # mean_label_seg_lens = [ + # mean_label_seg_lens[str(idx)] for idx in range(self.num_labels) + # ] + # self.out_mean_label_seg_lens_var.set(mean_label_seg_lens) + # + # with open("mean_label_seg_len.txt", "r") as f: + # self.out_mean_label_seg_len_var.set(float(f.read())) + # + # with open("mean_label_seq_len.txt", "r") as f: + # self.out_mean_label_seq_len_var.set(float(f.read())) + + # Set quantiles + with open("quantile_90", "r") as f: + self.out_90_quantile_var.set(int(float(f.read()))) + with open("quantile_95", "r") as f: + self.out_95_quantile_var.set(int(float(f.read()))) + with open("quantile_99", "r") as f: + self.out_99_quantile_var.set(int(float(f.read()))) + + shutil.move("statistics.txt", self.out_statistics.get_path()) + shutil.move("labels_histogram.pdf", self.out_labels_hist.get_path()) + shutil.move("mean_label_seg_lens.json", self.out_mean_label_seg_lens.get_path()) + shutil.move("mean_label_seg_len.txt", self.out_mean_label_seg_len.get_path()) + shutil.move("mean_label_seq_len.txt", self.out_mean_label_seq_len.get_path()) \ No newline at end of file diff --git a/users/gruev/statistics/bpe_statistics.py b/users/gruev/statistics/bpe_statistics.py new file mode 100755 index 000000000..99fe5a782 --- /dev/null +++ b/users/gruev/statistics/bpe_statistics.py @@ -0,0 +1,122 @@ +import subprocess +import shutil +import os + +from recipe.i6_core.util import create_executable + +from sisyphus import * + +import recipe.i6_private.users.gruev.tools as tools_mod + +tools_dir = os.path.dirname(tools_mod.__file__) + + +class BpeStatisticsJob(Job): + def __init__(self, bliss_lexicon, transcription, returnn_python_exe=None): + """ + :param bliss_lexicon: Bliss lexicon with BPE subword units as sequences + :param transcription: A text file with utterance transcriptions + """ + + self.bliss_lexicon = bliss_lexicon + self.transcription = transcription + + if returnn_python_exe is not None: + self.returnn_python_exe = returnn_python_exe + else: + self.returnn_python_exe = gs.RETURNN_PYTHON_EXE + + ## Example: + # Text: AMAZINGLY COMPLICATED COMPLICATED WANAMAKER + # Segmentation: AMA@@ ZINGLY COMP@@ LIC@@ ATED COMP@@ LIC@@ ATED WAN@@ AMA@@ KER + + # Num token per sequence: 1 sequence with 11 BPE tokens --> average is 11. 
+ # Num tokens per word: 4 words with 2 + 3 + 3 + 3 tokens --> average is 11/4 = 2.75 + # Num symbols per token: 7 tokens with 3 + 5 + 3 + 3 + 4 + 3 + 3 symbols --> average is 24/7 = 3.4 + # Token count per vocab: AMA@@ appears two times, COMP@@ and LIC@@ appear one time (per unique word!) + # Token count per corpus: AMA@@ appears two times, COMP@@ and LIC@@ apear two times (per sequence!) + + # Summary of all other statistics, register as output + self.out_bpe_statistics = self.output_path("bpe_statistics.txt") + + # Average number of BPE tokens for a sequence (utterance) + self.out_mean_num_token_per_sequence = self.output_path( + "mean_num_token_per_sequence.txt" + ) + self.out_num_token_per_sequence_histogram = self.output_path( + "num_token_per_sequence_histogram.pdf" + ) + + # Average number of BPE tokens for a single word in the corpus + self.out_mean_num_token_per_word = self.output_path( + "mean_num_token_per_word.txt" + ) + self.out_num_token_per_word_histogram = self.output_path( + "num_token_per_word_histogram.pdf" + ) + + # Average number of symbols that comprise a BPE token + self.out_mean_num_symbols_per_token = self.output_path( + "mean_num_symbols_per_token.txt" + ) + self.out_num_symbols_per_token_histogram = self.output_path( + "num_symbols_per_token_histogram.pdf" + ) + + # Number of BPE tokens per vocabulary (all words) + self.out_mean_token_count_per_vocab = self.output_path( + "mean_token_count_per_vocab.txt" + ) + self.out_token_count_per_vocab = self.output_path("token_count_per_vocab.json") + self.out_token_count_per_vocab_plot = self.output_path( + "token_count_per_vocab_plot.pdf" + ) + + # Number of BPE tokens per corpus (all sequence) + self.out_mean_token_count_per_corpus = self.output_path( + "mean_token_count_per_corpus.txt" + ) + self.out_token_count_per_corpus = self.output_path( + "token_count_per_corpus.json" + ) + self.out_token_count_per_corpus_plot = self.output_path( + "token_count_per_corpus_plot.pdf" + ) + + # OOV words in corpus + self.out_oov_words = self.output_path("oov_words.json") + + def tasks(self): + yield Task("run", rqmt={"cpu": 1, "mem": 2, "time": 2}) + + def run(self): + command = [ + self.returnn_python_exe.get_path(), + os.path.join(tools_dir, "bpe_statistics.py"), + self.bliss_lexicon.get_path(), + self.transcription.get_path(), + ] + + create_executable("rnn.sh", command) + subprocess.check_call(["./rnn.sh"]) + + # Register + shutil.move("bpe_statistics.txt", self.out_bpe_statistics.get_path()) + shutil.move("oov_words.json", self.out_oov_words.get_path()) + + for (stat, fig) in [ + ("num_token_per_sequence", "histogram"), + ("num_token_per_word", "histogram"), + ("num_symbols_per_token", "histogram"), + ("token_count_per_vocab", "plot"), + ("token_count_per_corpus", "plot"), + ]: + shutil.move( + f"mean_{stat}.txt", self.__dict__[f"out_mean_{stat}"].get_path() + ) + shutil.move( + f"{stat}_{fig}.pdf", self.__dict__[f"out_{stat}_{fig}"].get_path() + ) + + if stat in ["token_count_per_vocab", "token_count_per_corpus"]: + shutil.move(f"{stat}.json", self.__dict__[f"out_{stat}"].get_path()) diff --git a/users/gruev/statistics/schmitt_alignment.py b/users/gruev/statistics/schmitt_alignment.py new file mode 100644 index 000000000..aa10caa06 --- /dev/null +++ b/users/gruev/statistics/schmitt_alignment.py @@ -0,0 +1,94 @@ +from sisyphus import * + +from recipe.i6_core.util import create_executable +from recipe.i6_core.rasr.config import build_config_from_mapping +from recipe.i6_core.rasr.command import RasrCommand +from 
i6_core.returnn.config import ReturnnConfig +from i6_core.returnn.forward import ReturnnForwardJob + +from sisyphus import Path + +import subprocess +import tempfile +import shutil +import os +import json + +import recipe.i6_private.users.gruev.tools as tools_mod +# import recipe.i6_experiments.users.schmitt.tools as tools_mod +tools_dir = os.path.dirname(tools_mod.__file__) + +class AlignmentStatisticsJob(Job): + def __init__(self, alignment, seq_list_filter_file=None, blank_idx=0, silence_idx=None, + time_rqmt=2, returnn_python_exe=None, + returnn_root=None): + self.returnn_python_exe = (returnn_python_exe if returnn_python_exe is not None else gs.RETURNN_PYTHON_EXE) + self.returnn_root = (returnn_root if returnn_root is not None else gs.RETURNN_ROOT) + + self.alignment = alignment + self.seq_list_filter_file = seq_list_filter_file + self.blank_idx = blank_idx + self.silence_idx = silence_idx + self.out_statistics = self.output_path("statistics") + self.out_sil_hist = self.output_path("sil_histogram.pdf") + self.out_non_sil_hist = self.output_path("non_sil_histogram.pdf") + self.out_label_dep_stats = self.output_path("label_dep_mean_lens") + # self.out_label_dep_vars = self.output_path("label_dep_mean_vars") + self.out_label_dep_stats_var = self.output_var("label_dep_mean_lens_var", pickle=True) + # self.out_label_dep_vars_var = self.output_var("label_dep_mean_vars_var", pickle=True) + self.out_mean_non_sil_len = self.output_path("mean_non_sil_len") + self.out_mean_non_sil_len_var = self.output_var("mean_non_sil_len_var") + self.out_95_percentile_var = self.output_var("percentile_95") + self.out_90_percentile_var = self.output_var("percentile_90") + self.out_99_percentile_var = self.output_var("percentile_99") + + self.time_rqmt = time_rqmt + + def tasks(self): + yield Task("run", rqmt={"cpu": 1, "mem": 4, "time": self.time_rqmt}) + + def run(self): + command = [ + self.returnn_python_exe.get_path(), + os.path.join(tools_dir, "schmitt_segment_statistics.py"), + self.alignment.get_path(), + "--blank-idx", str(self.blank_idx), "--sil-idx", str(self.silence_idx), + "--returnn-root", self.returnn_root.get_path() + ] + + if self.seq_list_filter_file: + command += ["--seq-list-filter-file", str(self.seq_list_filter_file)] + + create_executable("rnn.sh", command) + subprocess.check_call(["./rnn.sh"]) + + with open("label_dep_mean_lens", "r") as f: + label_dep_means = json.load(f) + label_dep_means = {int(k): v for k, v in label_dep_means.items()} + label_dep_means = [label_dep_means[idx] for idx in range(len(label_dep_means)) if idx > 1] + + # with open("label_dep_vars", "r") as f: + # label_dep_vars = json.load(f) + # label_dep_vars = {int(k): v for k, v in label_dep_vars.items()} + # label_dep_vars = [label_dep_vars[idx] for idx in range(len(label_dep_vars))] + + with open("mean_non_sil_len", "r") as f: + self.out_mean_non_sil_len_var.set(float(f.read())) + + # set percentiles + with open("percentile_90", "r") as f: + self.out_90_percentile_var.set(int(float(f.read()))) + with open("percentile_95", "r") as f: + self.out_95_percentile_var.set(int(float(f.read()))) + with open("percentile_99", "r") as f: + self.out_99_percentile_var.set(int(float(f.read()))) + + self.out_label_dep_stats_var.set(label_dep_means) + # self.out_label_dep_vars_var.set(label_dep_vars) + + shutil.move("statistics", self.out_statistics.get_path()) + shutil.move("sil_histogram.pdf", self.out_sil_hist.get_path()) + shutil.move("non_sil_histogram.pdf", self.out_non_sil_hist.get_path()) + shutil.move("label_dep_mean_lens", 
self.out_label_dep_stats.get_path()) + # shutil.move("label_dep_vars", self.out_label_dep_vars.get_path()) + shutil.move("mean_non_sil_len", self.out_mean_non_sil_len.get_path()) diff --git a/users/gruev/tools/bpe_statistics.py b/users/gruev/tools/bpe_statistics.py new file mode 100755 index 000000000..d23352b04 --- /dev/null +++ b/users/gruev/tools/bpe_statistics.py @@ -0,0 +1,370 @@ +import sys +import json +import gzip +import argparse +import numpy as np +from collections import Counter +import matplotlib.pyplot as plt +import xml.etree.ElementTree as ET + + +# kwargs['xlabel']: Optional[str] +# kwargs['ylabel']: Optional[str] +# kwargs['percentiles']: Optinal[List[int]] + + +def num_bin_heuristic(num_values): + """ Try to partition the data into 20-30 bins. """ + num_values_rounded = np.round(num_values / 5) * 5 + + min_remainder, num_bins = 30, 0 + for n in np.arange(20, 31): + if min_remainder > num_values_rounded % n: + min_remainder = num_values_rounded % n + num_bins = n + + return num_bins + + +def make_histogram(counter, alias, **kwargs): + # Data + data = [*counter.elements()] + + max_num_bins = num_bin_heuristic(len(counter)) + num_bins = min(len(counter) + 1, max_num_bins) + custom_binning = len(counter) < max_num_bins + + # TODO: Add more kwargs for histogram settings, etc. + + # Differentiate for smaller and larger number of bins + bins = max_num_bins + xticks = None + if custom_binning: + bins = np.arange(1, num_bins + 1) - 0.5 + xticks = np.arange(1, num_bins) + + # General + hist, _, _ = plt.hist( + data, bins=bins, alpha=0.75, range=(1, len(counter) + 2), ec="black", + ) + plt.xticks(xticks, fontsize=8) + plt.xlabel(kwargs.get("xlabel", "")) + plt.ylabel(kwargs.get("ylabel", "")) + + # Percentiles + percentiles = kwargs.get("percentiles", None) + if percentiles is not None: + for percentile in percentiles: + perc = np.percentile(data, percentile) + plt.axvline(x=perc, color="red") + plt.text( + perc, + max(hist) * 1.05, + str(percentile), + color="red", + ha="center", + va="bottom", + ) + + # Save and close + plt.savefig(f"{alias}.pdf") + plt.close() + + +# Too specific, make more general +def make_plot(counter, alias, **kwargs): + data = sorted(counter.items(), key=lambda x: x[1], reverse=True) + + total_cnt = 0 + indices, counts, cum_counts = [], [], [] + + for idx, (_, cnt) in enumerate(data, start=1): + indices.append(idx) + counts.append(cnt) + + total_cnt += cnt + cum_counts.append(total_cnt) + + # plt.figure(figsize=(18, 12)) # fixed + plot = plt.plot( + indices, + counts, + alpha=0.75, + marker="o", + markersize=3, + markerfacecolor="none", + linestyle="-", + color="royalblue", + ) + + # Adjust for percentiles' indices + ax = plt.gca() + yticks_diff = ax.get_yticks()[1] - ax.get_yticks()[0] + plt.ylim(-0.5 * yticks_diff) + bottom_loc = -0.4 * yticks_diff + + # Percentiles + cumulative_fraction = (np.array(cum_counts) / total_cnt) * 100 + percentiles = kwargs.get("percentiles", None) + if percentiles is not None: + for percentile in percentiles: + percentile_idx = np.argmax(cumulative_fraction >= percentile) + plt.axvline(x=percentile_idx + 1, color="r") + plt.text( + percentile_idx + 1, + max(counts) * 1.05, + str(percentile), + color="r", + ha="center", + va="bottom", + ) + plt.text( + percentile_idx + 5 * len(str(percentile_idx)), # 15-20 + bottom_loc, + str(percentile_idx), + color="r", + ha="left", + fontsize=8, + ) + + plt.xlabel(kwargs.get("xlabel", "")) + plt.ylabel(kwargs.get("ylabel", "")) + plt.savefig(f"{alias}.pdf") + plt.close() + + +def 
calc_bpe_statistics(bliss_lexicon, transcription, **kwargs): + try: + bliss_lexicon = gzip.open(bliss_lexicon, "r") + except Exception: + pass + + tree = ET.parse(bliss_lexicon) + root = tree.getroot() + + # If present, omit [blank], [SILENCE], [UNKNOWN] etc. + bpe_tokens = tree.findall( + './/phoneme-inventory/phoneme[variation="context"]/symbol' + ) + bpe_vocab = [bpe_token.text for bpe_token in bpe_tokens] + + # Average number of symbols that comprise a BPE token + # num_bpe_tokens = len(bpe_vocab), avoid redundant computations + total_num_symbols = 0 + symbols_per_token = Counter() + num_symbols_per_token = Counter() + + for bpe_token in bpe_vocab: + curr_num_symbols = len(bpe_token.replace("@@", "")) + symbols_per_token.update({bpe_token: curr_num_symbols}) + num_symbols_per_token.update([curr_num_symbols]) + total_num_symbols += curr_num_symbols + + # Mean (token-level statistics) + mean_num_symbols_per_token = total_num_symbols / len(bpe_vocab) + filename = "mean_num_symbols_per_token.txt" + with open(filename, "w+") as f: + f.write(str(float(mean_num_symbols_per_token))) + + # # Histogram (token-level statistics) + # data = [*num_symbols_per_token.elements()] + # num_bins = len(num_symbols_per_token.keys()) + 1 # capture full range of values + # + # # General + # plt.hist(data, bins=np.arange(1, num_bins+1) - 0.5, alpha=0.75, range=(1, num_bins+1), ec='black') + # plt.xticks(np.arange(1,num_bins), fontsize=8) + # plt.xlabel("Number of symbols per BPE token") + # plt.ylabel("Number of occurrences") + # + # # 95-th quantile w/ disambiguation + # percentile = np.percentile(data, 95) + # plt.axvline(x=percentile, color='red') + # plt.text(percentile, max(hist) + 40, '95', color='red', ha='center', va='bottom') + # + # plt.savefig("num_symbols_per_token_histogram.pdf") + # plt.close() + + # Histogram (token-level statistics) + kwargs.update( + { + "xlabel": "Number of symbols per BPE token", + "ylabel": "Number of occurrences", + "percentiles": [95], + } + ) + make_histogram(num_symbols_per_token, "num_symbols_per_token_histogram", **kwargs) + + # --------------------------------- # + + # Word-level statistics - iterate over tokenization-dict + tokenization = {} + # Take care to omit special lemmata + for lemma in root.findall(".//lemma"): + if not lemma.attrib: + orth = lemma.find(".//orth") + phon = lemma.find(".//phon") + + # Only works for unique sequences + if orth is not None and phon is not None: + tokenization[orth.text] = phon.text.split() + + total_num_tokens = 0 # num_words = len(tokenization), avoid redundant computations + token_count_per_vocab = Counter() + num_tokens_per_word = Counter() + + for word, subwords in tokenization.items(): + curr_num_tokens = len(subwords) + num_tokens_per_word.update([curr_num_tokens]) + total_num_tokens += curr_num_tokens + for subword in subwords: + token_count_per_vocab.update([subword]) + + # Means (word-level/vocab-level statistics) + mean_num_token_per_word = total_num_tokens / len(tokenization) # number of words + filename = "mean_num_token_per_word.txt" + with open(filename, "w+") as f: + f.write(str(float(mean_num_token_per_word))) + + mean_token_count_per_vocab = total_num_tokens / len( + token_count_per_vocab + ) # number of BPE subwords + filename = "mean_token_count_per_vocab.txt" + with open(filename, "w+") as f: + f.write(str(float(mean_token_count_per_vocab))) + + # Dump BPE token frequency (vocabulary) + filename = "token_count_per_vocab.json" + with open(filename, "w+") as f: + json.dump(token_count_per_vocab, f, indent=0) + 
+ # Visualisations (word-level statistics) + kwargs.update( + { + "xlabel": "Number of BPE tokens per word", + # 'ylabel': "Number of occurrences", + "percentiles": [95], + } + ) + make_histogram(num_tokens_per_word, "num_token_per_word_histogram", **kwargs) + + kwargs.update( + { + "xlabel": "BPE tokens' counts (vocabulary)", + # 'ylabel': "Number of occurrences", + "percentiles": [90, 95, 99], + } + ) + make_plot(token_count_per_vocab, "token_count_per_vocab_plot", **kwargs) + + # --------------------------------- # + + # Sequence-level statistics - iterate over transcription (list of text sequences) + total_num_tokens = 0 + total_num_sequences = 0 + oov_words = Counter() + token_count_per_corpus = Counter() + num_tokens_per_sequence = Counter() + + with open(transcription, "rt") as f: + for sequence in f: + curr_num_tokens = 0 + + for word in sequence.split(): + subwords = tokenization.get(word, []) + if subwords: + curr_num_tokens += len(subwords) + + for subword in subwords: + token_count_per_corpus.update([subword]) + else: + oov_words.update([word]) + + num_tokens_per_sequence.update( + [curr_num_tokens] + ) # num_tokens per given sequence + total_num_tokens += curr_num_tokens + total_num_sequences += 1 + + # Means (transcription-level statistics) + mean_num_token_per_sequence = total_num_tokens / total_num_sequences + filename = "mean_num_token_per_sequence.txt" + with open(filename, "w+") as f: + f.write(str(float(mean_num_token_per_sequence))) + + mean_token_count_per_corpus = total_num_tokens / len( + token_count_per_corpus + ) # number of BPE subwords + filename = "mean_token_count_per_corpus.txt" + with open(filename, "w+") as f: + f.write(str(float(mean_token_count_per_corpus))) + + # Dump BPE token frequency (corpus) + filename = "token_count_per_corpus.json" + with open(filename, "w+") as f: + json.dump(token_count_per_corpus, f, indent=0) + + # Dump OOV words (words without BPE-tokenization in lexicon) + filename = "oov_words.json" + with open(filename, "w+") as f: + json.dump(oov_words, f, indent=0) + + # Visualisations (transcription-level statistics) + kwargs.update( + { + "xlabel": "Number of BPE tokens per sequence", + # 'ylabel': "Number of occurrences", + "percentiles": [95], + } + ) + make_histogram( + num_tokens_per_sequence, "num_token_per_sequence_histogram", **kwargs + ) + + kwargs.update( + { + "xlabel": "BPE tokens' counts (corpus)", + # 'ylabel': "Number of occurrences", + "percentiles": [90, 95, 99], + } + ) + make_plot(token_count_per_corpus, "token_count_per_corpus_plot", **kwargs) + + # --------------------------------- # + + filename = "bpe_statistics.txt" + with open(filename, "w+") as f: + f.write("BPE STATISTICS:\n") + f.write( + f"\t Mean number of symbols per BPE token: {mean_num_symbols_per_token}.\n\n" + ) + f.write(f"\t Mean number of BPE tokens per word: {mean_num_token_per_word}.\n") + f.write( + f"\t Mean number of BPE tokens per sequence: {mean_num_token_per_sequence}.\n\n" + ) + f.write( + f"\t Mean count of BPE tokens in vocabulary: {mean_token_count_per_vocab}.\n" + ) + f.write( + f"\t Mean count of BPE tokens in corpus: {mean_token_count_per_corpus}.\n" + ) + + +def main(): + arg_parser = argparse.ArgumentParser( + description="Calculate BPE subword statistics." + ) + arg_parser.add_argument( + "bliss_lexicon", help="Bliss lexicon with word-to-tokenization correspondence." + ) + arg_parser.add_argument( + "transcription", help="Corpus text corresponding to a Bliss corpus." 
+ ) + args = arg_parser.parse_args() + + # TODO + hist_kwargs = {"xlabel": None, "ylabel": None, "percentile": [95]} + calc_bpe_statistics(args.bliss_lexicon, args.transcription, **hist_kwargs) + + +if __name__ == "__main__": + main() diff --git a/users/gruev/tools/schmitt_segment_statistics.py b/users/gruev/tools/schmitt_segment_statistics.py new file mode 100644 index 000000000..6ff0769e1 --- /dev/null +++ b/users/gruev/tools/schmitt_segment_statistics.py @@ -0,0 +1,264 @@ +import argparse +import json +import sys +import numpy as np +from collections import Counter +import matplotlib.pyplot as plt + +def calc_segment_stats_with_sil(blank_idx, sil_idx): + + dataset.init_seq_order() + seq_idx = 0 + inter_sil_seg_len = 0 + init_sil_seg_len = 0 + final_sil_seg_len = 0 + label_seg_len = 0 + num_blank_frames = 0 + num_label_segs = 0 + num_sil_segs = 0 + num_init_sil_segs = 0 + num_final_sil_segs = 0 + num_seqs = 0 + max_seg_len = 0 + + label_dependent_seg_lens = Counter() + label_dependent_num_segs = Counter() + + map_non_sil_seg_len_to_count = Counter() + map_sil_seg_len_to_count = Counter() + + while dataset.is_less_than_num_seqs(seq_idx): + num_seqs += 1 + # progress indication + if seq_idx % 1000 == 0: + complete_frac = dataset.get_complete_frac(seq_idx) + print("Progress: %.02f" % (complete_frac * 100)) + dataset.load_seqs(seq_idx, seq_idx + 1) + data = dataset.get_data(seq_idx, "data") + + # if data[-1] == blank_idx: + # print("LAST IDX IS BLANK!") + # print(data) + # print(dataset.get_tag(seq_idx)) + # print("------------------------") + + # find non-blanks and silence + non_blank_idxs = np.where(data != blank_idx)[0] + sil_idxs = np.where(data == sil_idx)[0] + + non_blank_data = data[data != blank_idx] + + # count number of segments and number of blank frames + num_label_segs += len(non_blank_idxs) - len(sil_idxs) + num_sil_segs += len(sil_idxs) + num_blank_frames += len(data) - len(non_blank_idxs) + + # if there are only blanks, we skip the seq as there are no segments + if non_blank_idxs.size == 0: + seq_idx += 1 + continue + else: + prev_idx = 0 + try: + # go through non blanks and count segment len + # differ between sil_beginning, sil_mid, sil_end and non-sil + for i, idx in enumerate(non_blank_idxs): + seg_len = idx - prev_idx + # first segment is always one too short because of prev_idx = 0 + if prev_idx == 0: + seg_len += 1 + + if seg_len > max_seg_len: + max_seg_len = seg_len + + # if seg_len > 20: + # print("SEQ WITH SEG LEN OVER 20:\n") + # print(data) + # print(dataset.get_tag(seq_idx)) + # print("-------------------------------") + # + # if seg_len > 30: + # print("SEQ WITH SEG LEN OVER 30:\n") + # print(data) + # print(dataset.get_tag(seq_idx)) + # print("-------------------------------") + # + # if seg_len > 40: + # print("SEQ WITH SEG LEN OVER 40:\n") + # print(data) + # print(dataset.get_tag(seq_idx)) + # print("-------------------------------") + # + # if seg_len > 60: + # print("SEQ WITH SEG LEN OVER 60:\n") + # print(data) + # print(dataset.get_tag(seq_idx)) + # print("-------------------------------") + # + # if seg_len > 80: + # print("SEQ WITH SEG LEN OVER 80:\n") + # print(data) + # print(dataset.get_tag(seq_idx)) + # print("-------------------------------") + + label_dependent_seg_lens.update({non_blank_data[i]: seg_len}) + label_dependent_num_segs.update([non_blank_data[i]]) + + if idx in sil_idxs: + if i == 0: + init_sil_seg_len += seg_len + num_init_sil_segs += 1 + map_sil_seg_len_to_count.update([seg_len]) + elif i == len(non_blank_idxs) - 1: + 
final_sil_seg_len += seg_len + num_final_sil_segs += 1 + map_sil_seg_len_to_count.update([seg_len]) + else: + inter_sil_seg_len += seg_len + map_sil_seg_len_to_count.update([seg_len]) + else: + label_seg_len += seg_len + map_non_sil_seg_len_to_count.update([seg_len]) + + prev_idx = idx + except IndexError: + continue + + seq_idx += 1 + + mean_init_sil_len = init_sil_seg_len / num_init_sil_segs if num_init_sil_segs > 0 else 0 + mean_final_sil_len = final_sil_seg_len / num_final_sil_segs if num_final_sil_segs > 0 else 0 + mean_inter_sil_len = inter_sil_seg_len / (num_sil_segs - num_init_sil_segs - num_final_sil_segs) if inter_sil_seg_len > 0 else 0 + mean_total_sil_len = (init_sil_seg_len + final_sil_seg_len + inter_sil_seg_len) / num_seqs + + mean_label_len = label_seg_len / num_label_segs + mean_total_label_len = label_seg_len / num_seqs + + label_dependent_mean_seg_lens = {int(idx): label_dependent_seg_lens[idx] / label_dependent_num_segs[idx] for idx in label_dependent_seg_lens } + label_dependent_mean_seg_lens.update({idx: mean_label_len for idx in range(blank_idx) if idx not in label_dependent_mean_seg_lens}) + + mean_seq_len = (num_blank_frames + num_sil_segs + num_label_segs) / num_seqs + + num_segments_shorter2 = sum([count for seg_len, count in map_non_sil_seg_len_to_count.items() if seg_len < 2]) + num_segments_shorter4 = sum([count for seg_len, count in map_non_sil_seg_len_to_count.items() if seg_len < 4]) + num_segments_shorter8 = sum([count for seg_len, count in map_non_sil_seg_len_to_count.items() if seg_len < 8]) + num_segments_shorter16 = sum([count for seg_len, count in map_non_sil_seg_len_to_count.items() if seg_len < 16]) + num_segments_shorter32 = sum([count for seg_len, count in map_non_sil_seg_len_to_count.items() if seg_len < 32]) + num_segments_shorter64 = sum([count for seg_len, count in map_non_sil_seg_len_to_count.items() if seg_len < 64]) + num_segments_shorter128 = sum([count for seg_len, count in map_non_sil_seg_len_to_count.items() if seg_len < 128]) + + filename = "statistics" + with open(filename, "w+") as f: + f.write("Segment statistics: \n\n") + f.write("\tSilence: \n") + f.write("\t\tInitial:\n") + f.write("\t\t\tMean length: %f \n" % mean_init_sil_len) + f.write("\t\t\tNum segments: %f \n" % num_init_sil_segs) + f.write("\t\tIntermediate:\n") + f.write("\t\t\tMean length: %f \n" % mean_inter_sil_len) + f.write("\t\t\tNum segments: %f \n" % (num_sil_segs - num_init_sil_segs - num_final_sil_segs)) + f.write("\t\tFinal:\n") + f.write("\t\t\tMean length: %f \n" % mean_final_sil_len) + f.write("\t\t\tNum segments: %f \n" % num_final_sil_segs) + f.write("\t\tTotal per sequence:\n") + f.write("\t\t\tMean length: %f \n" % mean_total_sil_len) + f.write("\t\t\tNum segments: %f \n" % num_sil_segs) + f.write("\n") + f.write("\tNon-silence: \n") + f.write("\t\tMean length per segment: %f \n" % mean_label_len) + f.write("\t\tMean length per sequence: %f \n" % mean_total_label_len) + f.write("\t\tNum segments: %f \n" % num_label_segs) + f.write("\t\tPercent segments shorter than x frames: \n") + f.write("\t\tx = 2: %f \n" % (num_segments_shorter2 / num_label_segs)) + f.write("\t\tx = 4: %f \n" % (num_segments_shorter4 / num_label_segs)) + f.write("\t\tx = 8: %f \n" % (num_segments_shorter8 / num_label_segs)) + f.write("\t\tx = 16: %f \n" % (num_segments_shorter16 / num_label_segs)) + f.write("\t\tx = 32: %f \n" % (num_segments_shorter32 / num_label_segs)) + f.write("\t\tx = 64: %f \n" % (num_segments_shorter64 / num_label_segs)) + f.write("\t\tx = 128: %f \n" % 
(num_segments_shorter128 / num_label_segs)) + f.write("\n") + f.write("Overall maximum segment length: %d \n" % max_seg_len) + f.write("\n") + f.write("\n") + f.write("Sequence statistics: \n\n") + f.write("\tMean length: %f \n" % mean_seq_len) + f.write("\tNum sequences: %f \n" % num_seqs) + + filename = "mean_non_sil_len" + with open(filename, "w+") as f: + f.write(str(float(mean_label_len))) + + filename = "label_dep_mean_lens" + with open(filename, "w+") as f: + json.dump(label_dependent_mean_seg_lens, f) + + # plot histograms non-sil segment lens + hist_data = [item for seg_len, count in map_non_sil_seg_len_to_count.items() for item in [seg_len] * count] + plt.hist(hist_data, bins=30, range=(0, 50)) + ax = plt.gca() + quantiles = [np.quantile(hist_data, q) for q in [.90, .95, .99]] + for n, q in zip([90, 95, 99], quantiles): + # write quantiles to file + with open("percentile_%s" % n, "w+") as f: + f.write(str(q)) + ax.axvline(q, color="r") + plt.savefig("non_sil_histogram.pdf") + plt.close() + + # plot histograms sil segment lens + hist_data = [item for seg_len, count in map_sil_seg_len_to_count.items() for item in [seg_len] * count] + plt.hist(hist_data, bins=40, range=(0, 100)) + if len(hist_data) != 0: + ax = plt.gca() + quantiles = [np.quantile(hist_data, q) for q in [.90, .95, .99]] + for q in quantiles: + ax.axvline(q, color="r") + plt.savefig("sil_histogram.pdf") + plt.close() + + +def init(hdf_file, seq_list_filter_file): + rnn.init_better_exchook() + rnn.init_thread_join_hack() + dataset_dict = { + "class": "HDFDataset", "files": [hdf_file], "use_cache_manager": True, "seq_list_filter_file": seq_list_filter_file + } + + rnn.init_config(config_filename=None, default_config={"cache_size": 0}) + global config + config = rnn.config + config.set("log", None) + global dataset + dataset = rnn.init_dataset(dataset_dict) + rnn.init_log() + print("Returnn segment-statistics starting up", file=rnn.log.v2) + rnn.returnn_greeting() + rnn.init_faulthandler() + rnn.init_config_json_network() + + +def main(): + arg_parser = argparse.ArgumentParser(description="Calculate segment statistics.") + arg_parser.add_argument("hdf_file", help="hdf file which contains the extracted alignments of some corpus") + arg_parser.add_argument("--seq-list-filter-file", help="whitelist of sequences to use", default=None) + arg_parser.add_argument("--blank-idx", help="the blank index in the alignment", default=0, type=int) + arg_parser.add_argument("--sil-idx", help="the blank index in the alignment", default=None) + arg_parser.add_argument("--returnn-root", help="path to returnn root") + args = arg_parser.parse_args() + sys.path.insert(0, args.returnn_root) + global rnn + import returnn.__main__ as rnn + + init(args.hdf_file, args.seq_list_filter_file) + + try: + calc_segment_stats_with_sil(args.blank_idx, args.sil_idx) + except KeyboardInterrupt: + print("KeyboardInterrupt") + sys.exit(1) + finally: + rnn.finalize() + + +if __name__ == "__main__": + main() diff --git a/users/gruev/tools/segment_statistics.py b/users/gruev/tools/segment_statistics.py new file mode 100755 index 000000000..c88a8f760 --- /dev/null +++ b/users/gruev/tools/segment_statistics.py @@ -0,0 +1,270 @@ +import argparse +import json +import sys +import numpy as np +from collections import Counter +import matplotlib.pyplot as plt + +# This works if blank_idx == 0 +def calc_ctc_segment_stats(num_labels, blank_idx=0): + dataset.init_seq_order() + + # HDF Dataset Utterance (=Sequence) Iteration + # In the end, total_num_seq := (curr_seq_idx + 
1) + curr_seq_idx = 0 + + # blank and non-blank frame cnt (across utterances) + num_blank_frames, num_label_frames = 0, 0 + + # start-blank, end-blank, mid-blank (and non-blank) + # individual counts (per sequence) and number of such segments + # start_blank_cnt, mid_blank_cnt, end_blank_cnt = 0, 0, 0 + # start_blank_seg_cnt, end_blank_seg_cnt = 0, 0 + + # Maximal segment duration + max_seg_len = 0 + + # Counter for segment duration indexed by label idx + label_seg_lens = Counter() + # Counter for number of label idx occurrences + label_seg_freq = Counter() + # Counter for number label segment duration occurrences + label_seg_len_freq = Counter() + + # # Counter for blank segments duration + # blank_seg_len = Counter() + # # Counter for non-blank segments duration + # non_blank_seg_len = Counter() + + while dataset.is_less_than_num_seqs(curr_seq_idx): + dataset.load_seqs(curr_seq_idx, curr_seq_idx + 1) + data = dataset.get_data(curr_seq_idx, "data") + + # Filter [blank] only (no [SILENCE]) + blanks_idx = np.where(data == blank_idx)[0] + labels_idx = np.where(data != blank_idx)[0] + labels = data[labels_idx] + + # Counts for [blank] and labels (non-blanks) + num_blank_frames += len(blanks_idx) + num_label_frames += len(labels_idx) + + # If there are only blanks, skip the current sequence + if len(labels_idx) == 0: + curr_seq_idx += 1 + continue + + # Frame duration between non-blank indices + curr_seq_seg_len = np.diff(labels_idx, prepend=-1) + + # Update current max segment duration per sequence + curr_seq_max_seg_len = np.max(curr_seq_seg_len) + if curr_seq_max_seg_len > max_seg_len: + max_seg_len = curr_seq_max_seg_len + + # Update number of label idx occurrences + label_seg_freq.update(labels) + + # Update number of label segment duration occurrences + label_seg_len_freq.update(curr_seq_seg_len) + + # Update durations of label segments + for label, seg_len in zip(labels, curr_seq_seg_len): + label_seg_lens.update({label: seg_len}) + + # # Compute initial, intermediate, final blank frame counts + # first_non_blank, last_non_blank = labels_idx[0], labels_idx[-1] + # if first_non_blank > 0: + # start_blank_cnt += labels_idx[0] + # start_blank_seg_cnt += 1 + # if last_non_blank < len(data) - 1: + # end_blank_cnt += len(data) - labels_idx[-1] - 1 + # end_blank_seg_cnt += 1 + # mid_blank_cnt = np.count_nonzero(data == 0) - start_blank_cnt - end_blank_cnt + + curr_seq_idx += 1 + + # total_seg_cnt = sum(label_seg_freq.values()) # == num_label_frames + + # label_seg_lens holds the corresponding segment lengths + total_seg_len_1 = sum(label_seg_lens.values()) + total_seg_len_2 = sum([k * v for (k, v) in label_seg_len_freq.items()]) + assert total_seg_len_1 == total_seg_len_2 + + total_seg_len = total_seg_len_1 + + # Mean label length per segment + mean_label_seg_len = total_seg_len / num_label_frames + # Mean label length per sequence + mean_label_seq_len = total_seg_len / curr_seq_idx # =: num_seqs + + # Mean duration of segment for each label + mean_label_seg_lens = {} + + for idx in range(num_labels): + if label_seg_freq[idx] == 0: + mean_label_seg_lens[idx] = 0 + else: + mean_label_seg_lens[idx] = label_seg_lens[idx] / label_seg_freq[idx] + + # Mean sequence length + mean_seq_len = (num_blank_frames + num_label_frames) / curr_seq_idx + + # Length statistics of label segments + num_seg_lt2 = sum( + [ + count + for label_seg_len, count in label_seg_len_freq.items() + if label_seg_len < 2 + ] + ) + num_seg_lt4 = sum( + [ + count + for label_seg_len, count in label_seg_len_freq.items() + if 
label_seg_len < 4 + ] + ) + num_seg_lt8 = sum( + [ + count + for label_seg_len, count in label_seg_len_freq.items() + if label_seg_len < 8 + ] + ) + num_seg_lt16 = sum( + [ + count + for label_seg_len, count in label_seg_len_freq.items() + if label_seg_len < 16 + ] + ) + num_seg_lt32 = sum( + [ + count + for label_seg_len, count in label_seg_len_freq.items() + if label_seg_len < 32 + ] + ) + num_seg_lt64 = sum( + [ + count + for label_seg_len, count in label_seg_len_freq.items() + if label_seg_len < 64 + ] + ) + num_seg_lt128 = sum( + [ + count + for label_seg_len, count in label_seg_len_freq.items() + if label_seg_len < 128 + ] + ) + + # Overview of computed statistics + filename = "statistics.txt" + with open(filename, "w+") as f: + f.write("\tNon-silence: \n") + f.write("\t\tMean length per segment: %f \n" % mean_label_seg_len) + f.write("\t\tMean length per sequence: %f \n" % mean_label_seq_len) + f.write("\t\tNum segments: %f \n" % num_label_frames) + f.write("\t\tPercent segments shorter than x frames: \n") + f.write("\t\tx = 2: %f \n" % (num_seg_lt2 / num_label_frames)) + f.write("\t\tx = 4: %f \n" % (num_seg_lt4 / num_label_frames)) + f.write("\t\tx = 8: %f \n" % (num_seg_lt8 / num_label_frames)) + f.write("\t\tx = 16: %f \n" % (num_seg_lt16 / num_label_frames)) + f.write("\t\tx = 32: %f \n" % (num_seg_lt32 / num_label_frames)) + f.write("\t\tx = 64: %f \n" % (num_seg_lt64 / num_label_frames)) + f.write("\t\tx = 128: %f \n" % (num_seg_lt128 / num_label_frames)) + + f.write("\n") + f.write("Overall maximum segment length: %d \n" % max_seg_len) + f.write("\n\n") + + f.write("Sequence statistics: \n\n") + f.write("\tMean length: %f \n" % mean_seq_len) + f.write("\tNum sequences: %f \n" % curr_seq_idx) + + filename = "mean_label_seq_len.txt" + with open(filename, "w+") as f: + f.write(str(float(mean_label_seq_len))) + + filename = "mean_label_seg_len.txt" + with open(filename, "w+") as f: + f.write(str(float(mean_label_seg_len))) + + filename = "mean_label_seg_lens.json" + with open(filename, "w+") as f: + json.dump(mean_label_seg_lens, f) + + # Histogram for label segment lengths + hist_data = [ + item + for label_seg_len, count in label_seg_len_freq.items() + for item in [label_seg_len] * count + ] + plt.hist(hist_data, bins=40, range=(0, 40)) + ax = plt.gca() + quantiles = [np.quantile(hist_data, q) for q in [0.90, 0.95, 0.99]] + for n, q in zip([90, 95, 99], quantiles): + # Write quantiles to files + with open("quantile_%s" % n, "w+") as f: + f.write(str(q)) + ax.axvline(q, color="r") + plt.savefig("labels_histogram.pdf") + plt.close() + + +def init(hdf_file, seq_list_filter_file): + rnn.init_better_exchook() + rnn.init_thread_join_hack() + dataset_dict = { + "class": "HDFDataset", + "files": [hdf_file], + "use_cache_manager": True, + "seq_list_filter_file": seq_list_filter_file, + } + + rnn.init_config(config_filename=None, default_config={"cache_size": 0}) + global config + config = rnn.config + config.set("log", None) + global dataset + dataset = rnn.init_dataset(dataset_dict) + rnn.init_log() + print("Returnn segment-statistics starting up...", file=rnn.log.v2) + rnn.returnn_greeting() + rnn.init_faulthandler() + rnn.init_config_json_network() + + +def main(): + arg_parser = argparse.ArgumentParser(description="Calculate alignment statistics.") + arg_parser.add_argument( + "hdf_file", + help="hdf file which contains the extracted alignments of some corpus", + ) + arg_parser.add_argument( + "--seq-list-filter-file", help="whitelist of sequences to use", default=None + ) + 
+    arg_parser.add_argument(
+        "--blank-idx", help="the blank index in the alignment", default=0, type=int
+    )
+    arg_parser.add_argument(
+        "--num-labels", help="the total number of labels in the alignment", type=int
+    )
+    arg_parser.add_argument("--returnn-root", help="path to RETURNN root")
+
+    args = arg_parser.parse_args()
+    sys.path.insert(0, args.returnn_root)
+
+    global rnn
+    import returnn.__main__ as rnn
+
+    init(args.hdf_file, args.seq_list_filter_file)
+    calc_ctc_segment_stats(args.num_labels, args.blank_idx)
+    rnn.finalize()
+
+
+if __name__ == "__main__":
+    main()
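Usage note: a minimal sketch of how these new jobs could be wired into a Sisyphus graph. Only the job signatures from this diff are taken as-is; the import locations, all input paths (BPE codes/vocab, base lexicon, LM text, alignment HDF, RETURNN), and the label count are placeholders and have to be adapted to the actual setup.

from sisyphus import tk

# Hypothetical wiring -- import paths depend on where this recipe package is mounted.
from recipe.i6_private.users.gruev.lexicon.bpe_lexicon import (
    CreateBPELexiconJob,
    ApplyBPEToTextJob,
)
from recipe.i6_private.users.gruev.statistics.alignment import AlignmentStatisticsJob

# Assumed inputs, e.g. a Bliss lexicon and the outputs of a ReturnnTrainBpeJob.
base_lexicon = tk.Path("/path/to/lexicon.xml.gz")
bpe_codes = tk.Path("/path/to/bpe.codes")
bpe_vocab = tk.Path("/path/to/bpe.vocab")
subword_nmt_repo = tk.Path("/path/to/subword-nmt")

# Bliss lexicon whose phoneme inventory is the BPE vocabulary.
bpe_lexicon_job = CreateBPELexiconJob(
    base_lexicon_path=base_lexicon,
    bpe_codes=bpe_codes,
    bpe_vocab=bpe_vocab,
    subword_nmt_repo=subword_nmt_repo,
    unk_label="UNK",
)
tk.register_output("lexicon/bpe_lexicon.xml.gz", bpe_lexicon_job.out_lexicon)

# BPE-segmented text, e.g. for LM training.
bpe_text_job = ApplyBPEToTextJob(
    words_file=tk.Path("/path/to/lm_text.txt"),
    bpe_codes=bpe_codes,
    bpe_vocab=bpe_vocab,
    subword_nmt_repo=subword_nmt_repo,
    gzip_output=True,
)
tk.register_output("text/lm_text.bpe.gz", bpe_text_job.out_bpe_text)

# Segment statistics of a CTC alignment stored as an HDF dataset.
align_stats_job = AlignmentStatisticsJob(
    alignment=tk.Path("/path/to/alignment.hdf"),
    num_labels=184,  # placeholder: size of the label set incl. blank
    blank_idx=0,
    returnn_python_exe=tk.Path("/path/to/returnn_python_exe"),
    returnn_root=tk.Path("/path/to/returnn"),
)
tk.register_output("statistics/alignment_statistics.txt", align_stats_job.out_statistics)

For large text inputs, ApplyBPEToTextJob can be constructed with mini_task=False so that it is scheduled with the regular rqmt instead of running locally.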