Skip to content

Commit

Permalink
Update setup: mini-attention ILM (internal language model) estimation now working
Browse files Browse the repository at this point in the history
  • Loading branch information
robin-p-schmitt committed Mar 5, 2024
1 parent 98a7523 commit 52310c7
Show file tree
Hide file tree
Showing 7 changed files with 358 additions and 93 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,14 @@ def edit_network_only_train_length_model(self, net_dict: Dict):
def add_align_augment(self, net_dict, networks_dict, python_prolog):
raise NotImplementedError

def edit_network_freeze_layers(self, net_dict: Dict, layers_to_exclude: List[str]):
    """
    Recursively set ``trainable: False`` on every layer dict inside ``net_dict``.

    Layers whose name appears in ``layers_to_exclude`` are left untouched and
    not descended into (e.g. the freshly added mini-LSTM ILM layers, which are
    the only ones that should keep training).

    :param net_dict: RETURNN network dict (or a nested sub-network dict); mutated in place.
    :param layers_to_exclude: layer names to keep trainable.
    """
    # A dict carrying a "class" key is itself a layer spec -> freeze it.
    if "class" in net_dict:
        net_dict["trainable"] = False

    # Recurse into nested dicts (sub-networks / rec units), skipping excluded layers.
    # isinstance instead of `type(...) == dict` so dict subclasses are handled too.
    for name, value in net_dict.items():
        if isinstance(value, dict) and name not in layers_to_exclude:
            self.edit_network_freeze_layers(value, layers_to_exclude)

def get_lr_settings(self, lr_opts):
lr_settings = {}
if lr_opts["type"] == "newbob":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
from i6_experiments.users.schmitt.augmentation.alignment import shift_alignment_boundaries_func_str
from i6_experiments.users.schmitt.dynamic_lr import dynamic_lr_str
from i6_experiments.users.schmitt.chunking import custom_chunkin_func_str
from i6_experiments.users.schmitt.experiments.config.pipelines.global_vs_segmental_2022_23.dependencies.returnn.network_builder import network_builder
from i6_experiments.users.schmitt.experiments.config.pipelines.global_vs_segmental_2022_23.dependencies.returnn.network_builder import network_builder, ilm_correction
from i6_experiments.users.schmitt.experiments.config.pipelines.global_vs_segmental_2022_23.dependencies.returnn.network_builder.lm import lm_irie
from i6_experiments.users.schmitt.experiments.config.pipelines.global_vs_segmental_2022_23.dependencies.returnn import custom_construction_algos
from i6_experiments.users.schmitt.experiments.config.pipelines.global_vs_segmental_2022_23.dependencies.returnn.config_builder.base import ConfigBuilder, SWBBlstmConfigBuilder, SwbConformerConfigBuilder, LibrispeechConformerConfigBuilder, ConformerConfigBuilder
from i6_experiments.users.schmitt.experiments.config.pipelines.global_vs_segmental_2022_23.dependencies.general.rasr.exes import RasrExecutables
Expand All @@ -28,11 +29,75 @@


class GlobalConfigBuilder(ConfigBuilder, ABC):
def __init__(self, dependencies: SegmentalLabelDefinition, **kwargs):
def __init__(self, dependencies: GlobalLabelDefinition, **kwargs):
    """
    :param dependencies: global (attention-based, non-segmental) label definition;
      narrowed here from the SegmentalLabelDefinition the base class previously took.
    :param kwargs: forwarded unchanged to ConfigBuilder.__init__.
    """
    super().__init__(dependencies=dependencies, **kwargs)

    # Re-assign so this attribute carries the narrower GlobalLabelDefinition
    # type for methods of this subclass (the base class also sets it).
    self.dependencies = dependencies

def get_train_config(self, opts: Dict, python_epilog: Optional[Dict] = None):
    """
    Build the training config; optionally set up mini-LSTM ("mini att") ILM
    estimation on top of the base training network.

    When ``opts["train_mini_lstm_opts"]`` is given (may be ``{}``), a mini-LSTM
    attention approximation is added to the "output" rec layer, every existing
    layer is frozen, and only the mini-LSTM layers are trained. Optionally a
    squared-error loss is added via ``use_se_loss``.

    :param opts: training options; relevant keys here: "train_mini_lstm_opts"
      (with optional "mini_att_in_s" and "use_se_loss" sub-options).
    :param python_epilog: forwarded to the base class.
    :return: ReturnnConfig-like object from the base class, possibly modified.
    """
    train_config = super().get_train_config(opts=opts, python_epilog=python_epilog)

    # NOTE: a leftover debug `print(opts["train_mini_lstm_opts"])` was removed here:
    # it raised KeyError whenever the key was absent, although the check below
    # deliberately uses .get() to tolerate a missing key.
    if opts.get("train_mini_lstm_opts") is not None:  # need to check for None because it can be {}
        ilm_correction.add_mini_lstm(
            network=train_config.config["network"],
            rec_layer_name="output",
            train=True,
            mini_att_in_s_for_train=opts["train_mini_lstm_opts"].get("mini_att_in_s", False)
        )
        # Freeze the whole network except the mini-LSTM layers, then re-enable
        # the rec layer itself so the mini-LSTM inside it receives gradients.
        self.edit_network_freeze_layers(
            train_config.config["network"],
            layers_to_exclude=["mini_att_lstm", "mini_att"]
        )
        train_config.config["network"]["output"]["trainable"] = True
        # "att" must be an output layer so the mini-LSTM can regress against it.
        train_config.config["network"]["output"]["unit"]["att"]["is_output_layer"] = True

        if opts["train_mini_lstm_opts"].get("use_se_loss", False):
            ilm_correction.add_se_loss(
                network=train_config.config["network"],
                rec_layer_name="output",
            )

    return train_config

def get_recog_config(self, opts: Dict):
    """
    Build the recognition config; optionally add ILM correction (via a
    pretrained mini-LSTM) and/or external LM fusion to the "output" rec layer.

    :param opts: recognition options; relevant keys here:
      "ilm_correction_opts" (needs "mini_att_checkpoint") and "lm_opts".
    :return: ReturnnConfig-like object from the base class, possibly modified.
    """
    recog_config = super().get_recog_config(opts)

    ilm_opts = opts.get("ilm_correction_opts")
    if ilm_opts is not None:
        # Recreate the mini-LSTM in inference mode and subtract its ILM
        # estimate from the label posteriors.
        ilm_correction.add_mini_lstm(
            network=recog_config.config["network"],
            rec_layer_name="output",
            train=False,
        )
        ilm_correction.add_ilm_correction(
            network=recog_config.config["network"],
            rec_layer_name="output",
            target_num_labels=self.dependencies.model_hyperparameters.target_num_labels,
            opts=ilm_opts,
            label_prob_layer="output_prob",
        )
        # Load the trained mini-LSTM parameters from their own checkpoint.
        recog_config.config["preload_from_files"] = {
            "mini_lstm": {"filename": ilm_opts["mini_att_checkpoint"], "prefix": "preload_"},
        }

    lm_opts = opts.get("lm_opts")
    if lm_opts is not None:
        lm_irie.add_lm(
            network=recog_config.config["network"],
            rec_layer_name="output",
            target_num_labels=self.dependencies.model_hyperparameters.target_num_labels,
            opts=lm_opts,
            label_prob_layer="output_prob",
        )

    return recog_config

def get_compile_tf_graph_config(self, opts: Dict):
returnn_config = self.get_recog_config(opts)
del returnn_config.config["network"]["output"]["max_seq_len"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from i6_experiments.users.schmitt.dynamic_lr import dynamic_lr_str
from i6_experiments.users.schmitt.chunking import custom_chunkin_func_str, custom_chunkin_w_reduction_func_str
from i6_experiments.users.schmitt.experiments.config.pipelines.global_vs_segmental_2022_23.dependencies.returnn.network_builder import network_builder, network_builder2, ilm_correction
from i6_experiments.users.schmitt.experiments.config.pipelines.global_vs_segmental_2022_23.dependencies.returnn.network_builder.lm import lm_irie
from i6_experiments.users.schmitt.experiments.config.pipelines.global_vs_segmental_2022_23.dependencies.returnn import custom_construction_algos
from i6_experiments.users.schmitt.experiments.config.pipelines.global_vs_segmental_2022_23.dependencies.returnn.config_builder.base import ConfigBuilder, SWBBlstmConfigBuilder, SwbConformerConfigBuilder, LibrispeechConformerConfigBuilder, ConformerConfigBuilder
from i6_experiments.users.schmitt.experiments.config.pipelines.global_vs_segmental_2022_23.dependencies.general.rasr.exes import RasrExecutables
Expand Down Expand Up @@ -73,13 +74,20 @@ def get_train_config(self, opts: Dict, python_epilog: Optional[Dict] = None):
network=train_config.config["network"],
rec_layer_name="label_model",
train=True,
mini_att_in_s_for_train=opts["train_mini_lstm_opts"].get("mini_att_in_s", False)
)
self.edit_network_freeze_layers(
train_config.config["network"],
layers_to_exclude=["mini_att_lstm", "mini_att"]
)
train_config.config["network"]["label_model"]["trainable"] = True

if opts["train_mini_lstm_opts"].get("use_se_loss", False):
ilm_correction.add_se_loss(
network=train_config.config["network"],
rec_layer_name="label_model",
)

return train_config

def get_recog_config(self, opts: Dict):
Expand All @@ -90,7 +98,8 @@ def get_recog_config(self, opts: Dict):
ilm_correction.add_mini_lstm(
network=recog_config.config["network"],
rec_layer_name="output",
train=False
train=False,
use_mask_layer=True
)

if ilm_correction_opts.get("correct_eos", False):
Expand All @@ -100,7 +109,25 @@ def get_recog_config(self, opts: Dict):
network=recog_config.config["network"],
rec_layer_name="output",
target_num_labels=self.dependencies.model_hyperparameters.target_num_labels_wo_blank,
opts=ilm_correction_opts
opts=ilm_correction_opts,
label_prob_layer="label_log_prob"
)

recog_config.config["preload_from_files"] = {
"mini_lstm": {
"filename": ilm_correction_opts["mini_att_checkpoint"],
"prefix": "preload_",
}
}

lm_opts = opts.get("lm_opts", None)
if lm_opts is not None:
lm_irie.add_lm(
network=recog_config.config["network"],
rec_layer_name="output",
target_num_labels=self.dependencies.model_hyperparameters.target_num_labels_wo_blank,
opts=lm_opts,
label_prob_layer="label_log_prob"
)

return recog_config
Expand All @@ -116,14 +143,6 @@ def edit_network_only_train_length_model(self, net_dict: Dict):
if type(net_dict[item]) == dict and item != "output":
self.edit_network_only_train_length_model(net_dict[item])

def edit_network_freeze_layers(self, net_dict: Dict, layers_to_exclude: List[str]):
if "class" in net_dict:
net_dict["trainable"] = False

for item in net_dict:
if type(net_dict[item]) == dict and item not in layers_to_exclude:
self.edit_network_freeze_layers(net_dict[item], layers_to_exclude)

def get_dump_scores_config(self, corpus_key: str, opts: Dict):
returnn_config = self.get_eval_config(eval_corpus_key=corpus_key, opts=opts)

Expand Down
Loading

0 comments on commit 52310c7

Please sign in to comment.