From 52310c7a2de5daf079d1dfd2a40f4e96e48f3595 Mon Sep 17 00:00:00 2001
From: schmitt
Date: Fri, 23 Feb 2024 10:05:34 +0100
Subject: [PATCH] update setup: mini att ilm working now
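
Mini-att ILM correction is now working for both the global attention
decoder and the segmental (center-window) decoder:

  * add_mini_lstm() can build the mini-att LSTM inside a
    masked_computation layer (use_mask_layer) for segmental search,
    and can optionally feed prev:mini_att into the decoder state s
    during mini-LSTM training (mini_att_in_s).
  * add_ilm_correction() handles both label prob layers
    ("label_log_prob" for segmental, "output_prob" for global) and
    subtracts the scaled ILM log prob; for segmental models, the ILM
    EOS prob can additionally be subtracted from the blank log prob
    in the last frame (correct_eos).
  * optional squared-error loss between mini_att and att during
    mini-LSTM training (use_se_loss).
  * recognition loads the mini-LSTM params via preload_from_files and
    can additionally combine with an external LM (lm_irie.add_lm()).

Sketch of the new opts dicts, for illustration only (keys as used in
this patch; the values here are placeholders, not tuned settings):

    train_mini_lstm_opts = {
        "use_eos": False,
        "mini_att_in_s": True,
        "use_se_loss": True,
        "num_epochs": 10,  # popped by the pipeline to pick the mini-att checkpoint
    }
    ilm_correction_opts = {
        "scale": 0.3,          # placeholder ILM scale
        "correct_eos": True,   # segmental only
        "eos_idx": 0,          # placeholder EOS index
        # "mini_att_checkpoint" is filled in by the pipeline
    }
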
---
 .../returnn/config_builder/base.py            |   8 +
 .../returnn/config_builder/global_.py         |  69 +++++-
 .../returnn/config_builder/segmental.py       |  39 ++-
 .../returnn/network_builder/ilm_correction.py | 227 +++++++++++++-----
 .../center_window_att/base.py                 |  15 +-
 .../global_vs_segmental_2022_23/recog.py      |  77 ++++--
 .../global_vs_segmental_2022_23/train.py      |  16 +-
 7 files changed, 358 insertions(+), 93 deletions(-)

diff --git a/users/schmitt/experiments/config/pipelines/global_vs_segmental_2022_23/dependencies/returnn/config_builder/base.py b/users/schmitt/experiments/config/pipelines/global_vs_segmental_2022_23/dependencies/returnn/config_builder/base.py
index b8a42b3a9..26a6e5994 100644
--- a/users/schmitt/experiments/config/pipelines/global_vs_segmental_2022_23/dependencies/returnn/config_builder/base.py
+++ b/users/schmitt/experiments/config/pipelines/global_vs_segmental_2022_23/dependencies/returnn/config_builder/base.py
@@ -272,6 +272,14 @@ def edit_network_only_train_length_model(self, net_dict: Dict):
   def add_align_augment(self, net_dict, networks_dict, python_prolog):
     raise NotImplementedError
 
+  def edit_network_freeze_layers(self, net_dict: Dict, layers_to_exclude: List[str]):
+    if "class" in net_dict:
+      net_dict["trainable"] = False
+
+    for item in net_dict:
+      if type(net_dict[item]) == dict and item not in layers_to_exclude:
+        self.edit_network_freeze_layers(net_dict[item], layers_to_exclude)
+
   def get_lr_settings(self, lr_opts):
     lr_settings = {}
     if lr_opts["type"] == "newbob":
diff --git a/users/schmitt/experiments/config/pipelines/global_vs_segmental_2022_23/dependencies/returnn/config_builder/global_.py b/users/schmitt/experiments/config/pipelines/global_vs_segmental_2022_23/dependencies/returnn/config_builder/global_.py
index 6240c3f51..7cfb61f03 100644
--- a/users/schmitt/experiments/config/pipelines/global_vs_segmental_2022_23/dependencies/returnn/config_builder/global_.py
+++ b/users/schmitt/experiments/config/pipelines/global_vs_segmental_2022_23/dependencies/returnn/config_builder/global_.py
@@ -10,7 +10,8 @@
 from i6_experiments.users.schmitt.augmentation.alignment import shift_alignment_boundaries_func_str
 from i6_experiments.users.schmitt.dynamic_lr import dynamic_lr_str
 from i6_experiments.users.schmitt.chunking import custom_chunkin_func_str
-from i6_experiments.users.schmitt.experiments.config.pipelines.global_vs_segmental_2022_23.dependencies.returnn.network_builder import network_builder
+from i6_experiments.users.schmitt.experiments.config.pipelines.global_vs_segmental_2022_23.dependencies.returnn.network_builder import network_builder, ilm_correction
+from i6_experiments.users.schmitt.experiments.config.pipelines.global_vs_segmental_2022_23.dependencies.returnn.network_builder.lm import lm_irie
 from i6_experiments.users.schmitt.experiments.config.pipelines.global_vs_segmental_2022_23.dependencies.returnn import custom_construction_algos
 from i6_experiments.users.schmitt.experiments.config.pipelines.global_vs_segmental_2022_23.dependencies.returnn.config_builder.base import ConfigBuilder, SWBBlstmConfigBuilder, SwbConformerConfigBuilder, LibrispeechConformerConfigBuilder, ConformerConfigBuilder
 from i6_experiments.users.schmitt.experiments.config.pipelines.global_vs_segmental_2022_23.dependencies.general.rasr.exes import RasrExecutables
@@ -28,11 +29,75 @@ class GlobalConfigBuilder(ConfigBuilder,
   ABC):
-  def __init__(self, dependencies: SegmentalLabelDefinition, **kwargs):
+  def __init__(self, dependencies: GlobalLabelDefinition, **kwargs):
     super().__init__(dependencies=dependencies, **kwargs)
 
     self.dependencies = dependencies
 
+  def get_train_config(self, opts: Dict, python_epilog: Optional[Dict] = None):
+    train_config = super().get_train_config(opts=opts, python_epilog=python_epilog)
+
+    if opts.get("train_mini_lstm_opts") is not None:  # need to check for None because it can be {}
+      ilm_correction.add_mini_lstm(
+        network=train_config.config["network"],
+        rec_layer_name="output",
+        train=True,
+        mini_att_in_s_for_train=opts["train_mini_lstm_opts"].get("mini_att_in_s", False)
+      )
+      self.edit_network_freeze_layers(
+        train_config.config["network"],
+        layers_to_exclude=["mini_att_lstm", "mini_att"]
+      )
+      train_config.config["network"]["output"]["trainable"] = True
+      train_config.config["network"]["output"]["unit"]["att"]["is_output_layer"] = True
+
+      if opts["train_mini_lstm_opts"].get("use_se_loss", False):
+        ilm_correction.add_se_loss(
+          network=train_config.config["network"],
+          rec_layer_name="output",
+        )
+
+    return train_config
+
+  def get_recog_config(self, opts: Dict):
+    recog_config = super().get_recog_config(opts)
+
+    ilm_correction_opts = opts.get("ilm_correction_opts", None)
+    if ilm_correction_opts is not None:
+      ilm_correction.add_mini_lstm(
+        network=recog_config.config["network"],
+        rec_layer_name="output",
+        train=False
+      )
+
+      ilm_correction.add_ilm_correction(
+        network=recog_config.config["network"],
+        rec_layer_name="output",
+        target_num_labels=self.dependencies.model_hyperparameters.target_num_labels,
+        opts=ilm_correction_opts,
+        label_prob_layer="output_prob"
+      )
+
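+      # the mini-LSTM params come from a separately trained checkpoint; they are
+      # mapped into the "preload_mini_att*" layers via the "preload_" name prefix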
+      recog_config.config["preload_from_files"] = {
+        "mini_lstm": {
+          "filename": ilm_correction_opts["mini_att_checkpoint"],
+          "prefix": "preload_",
+        }
+      }
+
+    lm_opts = opts.get("lm_opts", None)
+    if lm_opts is not None:
+      lm_irie.add_lm(
+        network=recog_config.config["network"],
+        rec_layer_name="output",
+        target_num_labels=self.dependencies.model_hyperparameters.target_num_labels,
+        opts=lm_opts,
+        label_prob_layer="output_prob"
+      )
+
+    return recog_config
+
   def get_compile_tf_graph_config(self, opts: Dict):
     returnn_config = self.get_recog_config(opts)
     del returnn_config.config["network"]["output"]["max_seq_len"]
diff --git a/users/schmitt/experiments/config/pipelines/global_vs_segmental_2022_23/dependencies/returnn/config_builder/segmental.py b/users/schmitt/experiments/config/pipelines/global_vs_segmental_2022_23/dependencies/returnn/config_builder/segmental.py
index a48ab2889..56187ca44 100644
--- a/users/schmitt/experiments/config/pipelines/global_vs_segmental_2022_23/dependencies/returnn/config_builder/segmental.py
+++ b/users/schmitt/experiments/config/pipelines/global_vs_segmental_2022_23/dependencies/returnn/config_builder/segmental.py
@@ -11,6 +11,7 @@
 from i6_experiments.users.schmitt.dynamic_lr import dynamic_lr_str
 from i6_experiments.users.schmitt.chunking import custom_chunkin_func_str, custom_chunkin_w_reduction_func_str
 from i6_experiments.users.schmitt.experiments.config.pipelines.global_vs_segmental_2022_23.dependencies.returnn.network_builder import network_builder, network_builder2, ilm_correction
+from i6_experiments.users.schmitt.experiments.config.pipelines.global_vs_segmental_2022_23.dependencies.returnn.network_builder.lm import lm_irie
 from i6_experiments.users.schmitt.experiments.config.pipelines.global_vs_segmental_2022_23.dependencies.returnn import custom_construction_algos
 from i6_experiments.users.schmitt.experiments.config.pipelines.global_vs_segmental_2022_23.dependencies.returnn.config_builder.base import ConfigBuilder, SWBBlstmConfigBuilder, SwbConformerConfigBuilder, LibrispeechConformerConfigBuilder, ConformerConfigBuilder
 from i6_experiments.users.schmitt.experiments.config.pipelines.global_vs_segmental_2022_23.dependencies.general.rasr.exes import RasrExecutables
@@ -73,6 +74,7 @@ def get_train_config(self, opts: Dict, python_epilog: Optional[Dict] = None):
         network=train_config.config["network"],
         rec_layer_name="label_model",
         train=True,
+        mini_att_in_s_for_train=opts["train_mini_lstm_opts"].get("mini_att_in_s", False)
       )
       self.edit_network_freeze_layers(
         train_config.config["network"],
@@ -80,6 +82,12 @@ def get_train_config(self, opts: Dict, python_epilog: Optional[Dict] = None):
       )
       train_config.config["network"]["label_model"]["trainable"] = True
 
+      if opts["train_mini_lstm_opts"].get("use_se_loss", False):
+        ilm_correction.add_se_loss(
+          network=train_config.config["network"],
+          rec_layer_name="label_model",
+        )
+
     return train_config
 
   def get_recog_config(self, opts: Dict):
@@ -90,7 +98,8 @@ def get_recog_config(self, opts: Dict):
       ilm_correction.add_mini_lstm(
         network=recog_config.config["network"],
         rec_layer_name="output",
-        train=False
+        train=False,
+        use_mask_layer=True
       )
 
       if ilm_correction_opts.get("correct_eos", False):
@@ -100,7 +109,25 @@ def get_recog_config(self, opts: Dict):
         network=recog_config.config["network"],
         rec_layer_name="output",
         target_num_labels=self.dependencies.model_hyperparameters.target_num_labels_wo_blank,
-        opts=ilm_correction_opts
+        opts=ilm_correction_opts,
+        label_prob_layer="label_log_prob"
+      )
+
+      recog_config.config["preload_from_files"] = {
+        "mini_lstm": {
+          "filename": ilm_correction_opts["mini_att_checkpoint"],
+          "prefix": "preload_",
+        }
+      }
+
+    lm_opts = opts.get("lm_opts", None)
+    if lm_opts is not None:
+      lm_irie.add_lm(
+        network=recog_config.config["network"],
+        rec_layer_name="output",
+        target_num_labels=self.dependencies.model_hyperparameters.target_num_labels_wo_blank,
+        opts=lm_opts,
+        label_prob_layer="label_log_prob"
       )
 
     return recog_config
@@ -116,14 +143,6 @@ def edit_network_only_train_length_model(self, net_dict: Dict):
       if type(net_dict[item]) == dict and item != "output":
         self.edit_network_only_train_length_model(net_dict[item])
 
-  def edit_network_freeze_layers(self, net_dict: Dict, layers_to_exclude: List[str]):
-    if "class" in net_dict:
-      net_dict["trainable"] = False
-
-    for item in net_dict:
-      if type(net_dict[item]) == dict and item not in layers_to_exclude:
-        self.edit_network_freeze_layers(net_dict[item], layers_to_exclude)
-
   def get_dump_scores_config(self, corpus_key: str, opts: Dict):
     returnn_config = self.get_eval_config(eval_corpus_key=corpus_key, opts=opts)
 
diff --git a/users/schmitt/experiments/config/pipelines/global_vs_segmental_2022_23/dependencies/returnn/network_builder/ilm_correction.py b/users/schmitt/experiments/config/pipelines/global_vs_segmental_2022_23/dependencies/returnn/network_builder/ilm_correction.py
index 0a528ab6d..9c18692fd 100644
--- a/users/schmitt/experiments/config/pipelines/global_vs_segmental_2022_23/dependencies/returnn/network_builder/ilm_correction.py
+++ b/users/schmitt/experiments/config/pipelines/global_vs_segmental_2022_23/dependencies/returnn/network_builder/ilm_correction.py
@@ -8,47 +8,79 @@ def add_mini_lstm(
   network: Dict,
   rec_layer_name: str,
-  train: bool = True
+  train: bool = True,
+  use_mask_layer: bool = False,
+  mini_att_in_s_for_train: bool = False
 ):
+  assert not (use_mask_layer and train), "mask layer only makes sense for inference"
   network[rec_layer_name]["unit"].update({
-    "mini_att_lstm": {
-      "class": "rec",
-      "unit": "nativelstm2",
-      "n_out": 50,
-      "direction": 1,
-      "from": "prev:target_embed",
-    },
-    "mini_att": {
+    "mini_att" if train else "preload_mini_att": {
       "class": "linear",
       "activation": None,
       "with_bias": True,
-      "from": "mini_att_lstm",
+      "from": "mini_att_lstm" if train else "preload_mini_att_lstm",
       "n_out": 512,
     },
   })
 
+  mini_att_lstm_dict = {
+    "mini_att_lstm" if train else "preload_mini_att_lstm": {
+      "class": "rec",
+      "unit": "nativelstm2",
+      "n_out": 50,
+      "direction": 1,
+      "from": "prev:target_embed",
+    },
+  }
+
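+  # during segmental search, the mini-att LSTM should only advance on emit
+  # frames: wrap it in a masked_computation layer and unmask it for blank
+  # frames; the name_scope override below keeps the checkpoint var names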
+  if use_mask_layer:
+    network[rec_layer_name]["unit"].update({
+      "preload_mini_att_lstm": {
+        "class": "unmask",
+        "from": "preload_mini_att_lstm_masked",
+        "mask": "prev:output_emit",
+      },
+      "preload_mini_att_lstm_masked": {
+        "class": "masked_computation",
+        "from": "prev:target_embed",
+        "mask": "prev:output_emit",
+        "unit": {
+          "class": "subnetwork",
+          "from": "data",
+          "subnetwork": {
+            **mini_att_lstm_dict,
+            "output": {
+              "class": "copy",
+              "from": "preload_mini_att_lstm",
+            }
+          }
+        }
+      }
+    })
+    network[rec_layer_name]["unit"]["preload_mini_att_lstm_masked"]["unit"]["subnetwork"]["preload_mini_att_lstm"]["from"] = "data"
+    network[rec_layer_name]["unit"]["preload_mini_att_lstm_masked"]["unit"]["subnetwork"]["preload_mini_att_lstm"]["name_scope"] = "/output/rec/preload_mini_att_lstm/rec"
+  else:
+    network[rec_layer_name]["unit"].update(mini_att_lstm_dict)
+
   if train:
     network[rec_layer_name]["unit"]["readout_in"]["from"] = ["s", "prev:target_embed", "mini_att"]
+    if mini_att_in_s_for_train:
+      network[rec_layer_name]["unit"]["s"]["from"] = ["prev:target_embed", "prev:mini_att"]
 
 
-def add_ilm_correction(network: Dict, rec_layer_name: str, target_num_labels: int, opts: Dict):
+def add_ilm_correction(
+  network: Dict,
+  rec_layer_name: str,
+  target_num_labels: int,
+  opts: Dict,
+  label_prob_layer: str
+):
   network[rec_layer_name]["unit"].update({
-    "prior_s": {
-      "class": "rnn_cell",
-      "unit": "zoneoutlstm",
-      "n_out": 1024,
-      "from": ["prev:target_embed", "prev:mini_att"],
-      "unit_opts": {
-        "zoneout_factor_cell": 0.15,
-        "zoneout_factor_output": 0.05,
-      },
-      "reuse_params": "s_masked/s"
-    },
     "prior_readout_in": {
       "class": "linear",
       "activation": None,
       "with_bias": True,
-      "from": ["prior_s", "prev:target_embed", "mini_att"],
+      "from": ["prior_s", "prev:target_embed", "preload_mini_att"],
       "n_out": 1024,
       "reuse_params": "readout_in"
     },
@@ -61,47 +93,134 @@ def add_ilm_correction(network: Dict, rec_layer_name: str, target_num_labels: in
     "prior_label_prob": {
       "class": "softmax",
       "from": ["prior_readout"],
-      "reuse_params": "label_log_prob",
       "n_out": target_num_labels,
     },
   })
 
-  if rec_layer_name == "output":
+  prior_s_dict = {
+    "prior_s": {
+      "class": "rnn_cell",
+      "unit": "zoneoutlstm",
+      "n_out": 1024,
+      "from": ["prev:target_embed", "prev:preload_mini_att"],
+      "unit_opts": {
+        "zoneout_factor_cell": 0.15,
+        "zoneout_factor_output": 0.05,
+      },
+    },
+  }
+
+  if label_prob_layer == "label_log_prob":
     network[rec_layer_name]["unit"].update({
-      "label_log_prob_wo_ilm": {
-        "class": "eval",
-        "eval": f"source(0) - {opts['scale']} * safe_log(source(1))",
-        "from": ["label_log_prob", "prior_label_prob"],
+      "prior_s": {
+        "class": "unmask",
+        "from": "prior_s_masked",
+        "mask": "prev:output_emit",
       },
+      "prior_s_masked": {
+        "class": "masked_computation",
+        "from": ["prev:target_embed"],
+        "mask": "prev:output_emit",
+        "unit": {
+          "class": "subnetwork",
+          "from": "data",
+          "subnetwork": {
+            **prior_s_dict,
+            "output": {
+              "class": "copy",
+              "from": "prior_s",
+            }
+          }
+        }
+      }
     })
-    network[rec_layer_name]["unit"]["label_log_prob_plus_emit"]["from"] = ["label_log_prob_wo_ilm", "emit_log_prob"]
+    network[rec_layer_name]["unit"]["prior_s_masked"]["unit"]["subnetwork"]["prior_s"]["from"] = ["data", "base:prev:preload_mini_att"]
+    network[rec_layer_name]["unit"]["prior_s_masked"]["unit"]["subnetwork"]["prior_s"]["name_scope"] = "/output/rec/s/rec"
+    network[rec_layer_name]["unit"]["prior_label_prob"]["reuse_params"] = "label_log_prob"
+  else:
+    assert label_prob_layer == "output_prob"
+    network[rec_layer_name]["unit"].update(prior_s_dict)
+    network[rec_layer_name]["unit"]["prior_s"]["reuse_params"] = "s"
+    network[rec_layer_name]["unit"]["prior_label_prob"]["reuse_params"] = "output_prob"
 
-  if opts["correct_eos"]:
-    add_is_last_frame_condition(network, rec_layer_name)  # adds layer "is_last_frame"
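+  # combine label prob and ILM prob: log p_combined = log p_label - scale * log p_ilm;
+  # if a combo layer already exists (from LM fusion), only append the ILM term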
"prior_s_masked", + "mask": "prev:output_emit", }, + "prior_s_masked": { + "class": "masked_computation", + "from": ["prev:target_embed"], + "mask": "prev:output_emit", + "unit": { + "class": "subnetwork", + "from": "data", + "subnetwork": { + **prior_s_dict, + "output": { + "class": "copy", + "from": "prior_s", + } + } + } + } }) - network[rec_layer_name]["unit"]["label_log_prob_plus_emit"]["from"] = ["label_log_prob_wo_ilm", "emit_log_prob"] + network[rec_layer_name]["unit"]["prior_s_masked"]["unit"]["subnetwork"]["prior_s"]["from"] = ["data", "base:prev:preload_mini_att"] + network[rec_layer_name]["unit"]["prior_s_masked"]["unit"]["subnetwork"]["prior_s"]["name_scope"] = "/output/rec/s/rec" + network[rec_layer_name]["unit"]["prior_label_prob"]["reuse_params"] = "label_log_prob" + else: + assert label_prob_layer == "output_prob" + network[rec_layer_name]["unit"].update(prior_s_dict) + network[rec_layer_name]["unit"]["prior_s"]["reuse_params"] = "s" + network[rec_layer_name]["unit"]["prior_label_prob"]["reuse_params"] = "output_prob" - if opts["correct_eos"]: - add_is_last_frame_condition(network, rec_layer_name) # adds layer "is_last_frame" + combo_label_prob_layer = f"combo_{label_prob_layer}" + if combo_label_prob_layer in network[rec_layer_name]["unit"]: + network[rec_layer_name]["unit"][combo_label_prob_layer]["from"].append("prior_label_prob") + network[rec_layer_name]["unit"][combo_label_prob_layer]["eval"] += f" - {opts['scale']} * safe_log(source(2))" + else: network[rec_layer_name]["unit"].update({ - "ilm_eos_prob": { - "class": "gather", - "from": "prior_label_prob", - "position": opts["eos_idx"], - "axis": "f", - }, - "ilm_eos_log_prob0": { + combo_label_prob_layer: { "class": "eval", - "eval": "safe_log(source(0))", - "from": "ilm_eos_prob", + "from": [label_prob_layer, "prior_label_prob"], }, - "ilm_eos_log_prob": { # this layer is only non-zero for the last frame - "class": "switch", - "condition": "is_last_frame", - "true_from": "ilm_eos_log_prob0", - "false_from": 0.0, - } }) + if label_prob_layer == "label_log_prob": + network[rec_layer_name]["unit"][combo_label_prob_layer][ + "eval"] = f"source(0) - {opts['scale']} * safe_log(source(1))" + network[rec_layer_name]["unit"]["label_log_prob_plus_emit"]["from"] = [combo_label_prob_layer, "emit_log_prob"] + else: + assert label_prob_layer == "output_prob" + network[rec_layer_name]["unit"][combo_label_prob_layer][ + "eval"] = f"safe_log(source(0)) - {opts['scale']} * safe_log(source(1))" + network[rec_layer_name]["unit"]["output"]["from"] = combo_label_prob_layer + network[rec_layer_name]["unit"]["output"]["input_type"] = "log_prob" + + # special eos handling only for segmental models + if label_prob_layer == "label_log_prob": + if opts["correct_eos"]: + add_is_last_frame_condition(network, rec_layer_name) # adds layer "is_last_frame" + network[rec_layer_name]["unit"].update({ + "ilm_eos_prob": { + "class": "gather", + "from": "prior_label_prob", + "position": opts["eos_idx"], + "axis": "f", + }, + "ilm_eos_log_prob0": { + "class": "eval", + "eval": "safe_log(source(0))", + "from": "ilm_eos_prob", + }, + "ilm_eos_log_prob": { # this layer is only non-zero for the last frame + "class": "switch", + "condition": "is_last_frame", + "true_from": "ilm_eos_log_prob0", + "false_from": 0.0, + } + }) + + assert network[rec_layer_name]["unit"]["blank_log_prob"]["from"] == "emit_prob0" \ + or network[rec_layer_name]["unit"]["blank_log_prob"]["from"] == ["emit_prob0", "lm_eos_log_prob"] \ + and 
+  network[rec_layer_name]["unit"].update({
+    "se_loss": {
+      "class": "eval",
+      "eval": "(source(0) - source(1)) ** 2",
+      "from": ["att", "mini_att"],
+    },
+    "att_loss": {
+      "class": "reduce",
+      "mode": "mean",
+      "axis": "F",
+      "from": "se_loss",
+      "loss": "as_is",
+      "loss_scale": 0.05,
+    },
+  })
 
-  assert network[rec_layer_name]["unit"]["blank_log_prob"]["from"] == "emit_prob0" \
-         and network[rec_layer_name]["unit"]["blank_log_prob"]["eval"] == "tf.math.log_sigmoid(-source(0))", (
-    "blank_log_prob layer is not as expected"
-  )
-  # in the last frame, we want to subtract the ilm eos log prob from the blank log prob
-  network[rec_layer_name]["unit"]["blank_log_prob"]["from"] = ["emit_prob0", "ilm_eos_log_prob"]
-  network[rec_layer_name]["unit"]["blank_log_prob"]["eval"] = "tf.math.log_sigmoid(-source(0)) - source(1)"
+  network[rec_layer_name]["unit"]["att"]["axes"] = "except_time"
diff --git a/users/schmitt/experiments/config/pipelines/global_vs_segmental_2022_23/pipelines/pipeline_ls_conf/center_window_att/base.py b/users/schmitt/experiments/config/pipelines/global_vs_segmental_2022_23/pipelines/pipeline_ls_conf/center_window_att/base.py
index 25264dd9b..231494b38 100644
--- a/users/schmitt/experiments/config/pipelines/global_vs_segmental_2022_23/pipelines/pipeline_ls_conf/center_window_att/base.py
+++ b/users/schmitt/experiments/config/pipelines/global_vs_segmental_2022_23/pipelines/pipeline_ls_conf/center_window_att/base.py
@@ -25,6 +25,7 @@
   "no_ctc_loss": False,
   "lr_opts": None,
   "train_mini_lstm_opts": None,
+  "cleanup_old_models": None
 }
 default_returnn_recog_opts = {
   "search_corpus_key": "dev-other",
@@ -34,6 +35,8 @@
   "analyse": True,
   "att_weight_seq_tags": None,
   "load_ignore_missing_vars": False,
+  "lm_opts": None,
+  "ilm_correction_opts": None,
 }
 default_rasr_recog_opts = {
   "search_corpus_key": "dev-other",
@@ -127,6 +130,8 @@ def returnn_recog_center_window_att_import_global(
     search_rqmt=_recog_opts["search_rqmt"],
     batch_size=_recog_opts["batch_size"],
     load_ignore_missing_vars=_recog_opts["load_ignore_missing_vars"],
+    lm_opts=_recog_opts["lm_opts"],
+    ilm_correction_opts=_recog_opts["ilm_correction_opts"],
   )
 
   recog_exp.run_eval()
@@ -244,6 +249,7 @@ def train_center_window_att_import_global(
     only_train_length_model=_train_opts["only_train_length_model"],
     no_ctc_loss=_train_opts["no_ctc_loss"],
     train_mini_lstm_opts=_train_opts["train_mini_lstm_opts"],
+    cleanup_old_models=_train_opts["cleanup_old_models"],
   )
 
   return train_exp.run_train()
@@ -275,6 +281,7 @@ def standard_train_recog_center_window_att_import_global(
   )
 
   train_opts["align_targets"] = align_targets
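+  # num_epochs is popped here: it only selects the mini-att checkpoint below
+  # and must not end up in the actual training opts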
"learning_rate_control_error_measure": "dev_error_label_model/output_prob" } }) - checkpoints, model_dir, learning_rates = train_center_window_att_import_global( + mini_att_checkpoints, model_dir, learning_rates = train_center_window_att_import_global( alias=alias, config_builder=config_builder, train_opts=train_opts, ) - checkpoint = checkpoints[10] + mini_att_checkpoint = mini_att_checkpoints[train_mini_att_num_epochs] _recog_opts = copy.deepcopy(recog_opts) + + if isinstance(_recog_opts, dict) and "ilm_correction_opts" in _recog_opts: + _recog_opts["ilm_correction_opts"]["mini_att_checkpoint"] = mini_att_checkpoint + if _recog_opts is None or _recog_opts.pop("returnn_recog", True): returnn_recog_center_window_att_import_global( alias=alias, diff --git a/users/schmitt/experiments/config/pipelines/global_vs_segmental_2022_23/recog.py b/users/schmitt/experiments/config/pipelines/global_vs_segmental_2022_23/recog.py index 41c1a26ee..09e7688a5 100644 --- a/users/schmitt/experiments/config/pipelines/global_vs_segmental_2022_23/recog.py +++ b/users/schmitt/experiments/config/pipelines/global_vs_segmental_2022_23/recog.py @@ -36,17 +36,48 @@ def __init__( config_builder: ConfigBuilder, checkpoint: Checkpoint, corpus_key: str, + ilm_correction_opts: Optional[Dict] = None, ): self.config_builder = config_builder self.checkpoint = checkpoint self.corpus_key = corpus_key self.stm_corpus_key = corpus_key + self.ilm_correction_opts = ilm_correction_opts self.alias = alias self.returnn_python_exe = self.config_builder.variant_params["returnn_python_exe"] self.returnn_root = self.config_builder.variant_params["returnn_root"] + def get_ilm_correction_alias(self, alias: str): + if self.ilm_correction_opts is not None: + alias += "/ilm_correction_scale-%f" % self.ilm_correction_opts["scale"] + if "correct_eos" in self.ilm_correction_opts: + if self.ilm_correction_opts["correct_eos"]: + alias += "/correct_eos" + else: + alias += "/wo_correct_eos" + if self.ilm_correction_opts.get("mini_att_in_s_for_train", False): + alias += "/w_mini_att_in_s_for_train" + else: + alias += "/wo_mini_att_in_s_for_train" + if self.ilm_correction_opts.get("use_se_loss", False): + alias += "/w_se_loss" + else: + alias += "/wo_se_loss" + if self.ilm_correction_opts.get("mini_att_train_num_epochs", None): + alias += "/mini_att_train_num_epochs-%d" % self.ilm_correction_opts["mini_att_train_num_epochs"] + else: + alias += "/wo_ilm_correction" + + return alias + + def get_config_recog_opts(self): + return { + "search_corpus_key": self.corpus_key, + "ilm_correction_opts": self.ilm_correction_opts, + } + @abstractmethod def get_ctm_path(self) -> Path: pass @@ -78,6 +109,7 @@ def __init__( search_rqmt: Optional[Dict], batch_size: Optional[int], load_ignore_missing_vars: bool = False, + lm_opts: Optional[Dict] = None, **kwargs): super().__init__(**kwargs) @@ -88,19 +120,33 @@ def __init__( self.concat_num = concat_num self.search_rqmt = search_rqmt self.load_ignore_missing_vars = load_ignore_missing_vars + self.lm_opts = lm_opts self.alias += "/returnn_decoding" - def get_recog_opts(self): - return { - "search_corpus_key": self.corpus_key, + if lm_opts is not None: + self.alias += "/bpe-lm-scale-%f" % (lm_opts["scale"],) + if "add_lm_eos_last_frame" in lm_opts: + self.alias += "_add-lm-eos-%s" % lm_opts["add_lm_eos_last_frame"] + self.alias = self.get_ilm_correction_alias(self.alias) + else: + self.alias += "/no-lm" + if self.ilm_correction_opts is not None: + self.alias = self.get_ilm_correction_alias(self.alias) + + def 
+    recog_opts = super().get_config_recog_opts()
+    recog_opts.update({
       "batch_size": self.batch_size,
       "dataset_opts": {"concat_num": self.concat_num},
       "load_ignore_missing_vars": self.load_ignore_missing_vars,
-    }
+      "lm_opts": self.lm_opts,
+    })
+    return recog_opts
 
   def get_ctm_path(self) -> Path:
-    recog_config = self.config_builder.get_recog_config(opts=self.get_recog_opts())
+    recog_config = self.config_builder.get_recog_config(opts=self.get_config_recog_opts())
 
     device = "gpu"
     if self.search_rqmt and self.search_rqmt["gpu"] == 0:
@@ -246,7 +292,6 @@ def __init__(
     lm_lookahead_opts: Optional[Dict] = None,
     open_vocab: bool = True,
     segment_list: Optional[List[str]] = None,
-    ilm_correction_opts: Optional[Dict] = None,
     **kwargs):
 
     super().__init__(**kwargs)
@@ -269,8 +314,6 @@ def __init__(
     self.native_lstm2_so_path = native_lstm2_so_path
     self.segment_list = segment_list
 
-    self.ilm_correction_opts = ilm_correction_opts
-
     self.alias += "/rasr_decoding/max-seg-len-%d" % self.max_segment_len
 
     if simple_beam_search:
@@ -284,9 +327,12 @@ def __init__(
     if lm_opts is not None:
       self.lm_opts = copy.deepcopy(lm_opts)
       self.alias += "/lm-%s_scale-%f" % (lm_opts["type"], lm_opts["scale"])
+      self.alias = self.get_ilm_correction_alias(self.alias)
     else:
       self.lm_opts = copy.deepcopy(self._get_default_lm_opts())
       self.alias += "/no_lm"
+      if self.ilm_correction_opts is not None:
+        self.alias = self.get_ilm_correction_alias(self.alias)
 
     if self.lm_lookahead_opts is not None:
       self.lm_lookahead_opts = copy.deepcopy(lm_lookahead_opts)
@@ -295,23 +341,8 @@ def __init__(
       self.lm_lookahead_opts = copy.deepcopy(self._get_default_lm_lookahead_opts())
       self.alias += "/wo-lm-lookahead"
 
-    if self.ilm_correction_opts is not None:
-      self.alias += "/ilm_correction_scale-%f" % self.ilm_correction_opts["scale"]
-      if self.ilm_correction_opts["correct_eos"]:
-        self.alias += "_correct_eos"
-      else:
-        self.alias += "_wo_correct_eos"
-    else:
-      self.alias += "/wo_ilm_correction"
-
-  def get_recog_opts(self):
-    return {
-      "search_corpus_key": self.corpus_key,
-      "ilm_correction_opts": self.ilm_correction_opts,
-    }
-
   def _get_returnn_graph(self) -> Path:
-    recog_config = self.config_builder.get_compile_tf_graph_config(opts=self.get_recog_opts())
+    recog_config = self.config_builder.get_compile_tf_graph_config(opts=self.get_config_recog_opts())
     recog_config.config["network"]["output"]["unit"]["target_embed_masked"]["unit"]["subnetwork"]["target_embed0"]["safe_embedding"] = True
     compile_job = CompileTFGraphJob(
       returnn_config=recog_config,
diff --git a/users/schmitt/experiments/config/pipelines/global_vs_segmental_2022_23/train.py b/users/schmitt/experiments/config/pipelines/global_vs_segmental_2022_23/train.py
index e900fbc35..9a25306cc 100644
--- a/users/schmitt/experiments/config/pipelines/global_vs_segmental_2022_23/train.py
+++ b/users/schmitt/experiments/config/pipelines/global_vs_segmental_2022_23/train.py
@@ -142,9 +142,21 @@ def run_train(
   if train_opts.get("train_mini_lstm_opts") is not None:  # need to check for None because it can be {}
     alias = alias + "_mini_lstm"
     if train_opts["train_mini_lstm_opts"].get("use_eos", False):
-      alias = alias + "_w_eos"
+      alias = alias + "/w_eos"
     else:
-      alias = alias + "_wo_eos"
+      alias = alias + "/wo_eos"
+
+    if train_opts["train_mini_lstm_opts"].get("mini_att_in_s", False):
+      alias = alias + "/w_mini_att_in_s"
+    else:
+      alias = alias + "/wo_mini_att_in_s"
+
+    if train_opts["train_mini_lstm_opts"].get("use_se_loss", False):
+      alias = alias + "/w_se_loss"
+    else:
+      alias = alias + "/wo_se_loss"
+
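+    # encode the epoch count in the alias so that mini-LSTM trainings with
+    # different num_epochs get distinct aliases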
alias + "/wo_se_loss" + + alias += "_%d_epochs" % n_epochs train_job = ReturnnTrainingJob( train_config,