From 2b0d18135dff9b016b3e0f51a6d17deaf1f7e8a7 Mon Sep 17 00:00:00 2001 From: marvin84 Date: Mon, 27 Nov 2023 15:42:33 +0100 Subject: [PATCH] continue with LFR models --- .../common/BASE_factored_hybrid_system.py | 2 + .../common/TF_factored_hybrid_system.py | 18 ++- .../setups/common/encoder/blstm/layer.py | 6 +- .../common/encoder/conformer/best_setup.py | 33 ++-- .../common/encoder/conformer/conformer.py | 24 +-- .../encoder/conformer/transformer_network.py | 22 +-- .../setups/common/helpers/network/__init__.py | 1 + .../setups/common/helpers/network/augment.py | 144 +++++++++++++----- .../setups/common/helpers/network/aux_loss.py | 37 +++-- .../common/helpers/network/frame_rate.py | 16 ++ .../common/helpers/train/network_params.py | 7 +- .../common/helpers/train/returnn_time_tag.py | 11 +- 12 files changed, 214 insertions(+), 107 deletions(-) create mode 100644 users/raissi/setups/common/helpers/network/frame_rate.py diff --git a/users/raissi/setups/common/BASE_factored_hybrid_system.py b/users/raissi/setups/common/BASE_factored_hybrid_system.py index c15480d61..e282076dc 100644 --- a/users/raissi/setups/common/BASE_factored_hybrid_system.py +++ b/users/raissi/setups/common/BASE_factored_hybrid_system.py @@ -61,6 +61,7 @@ from i6_experiments.users.raissi.setups.common.data.factored_label import LabelInfo from i6_experiments.users.raissi.setups.common.decoder.BASE_factored_hybrid_search import BASEFactoredHybridDecoder from i6_experiments.users.raissi.setups.common.decoder.config import PriorInfo, PosteriorScales, SearchParameters +from i6_experiments.users.raissi.setups.common.helpers.network.frame_rate import FrameRateReductionRatioinfo from i6_experiments.users.raissi.setups.common.util.hdf import RasrFeaturesToHdf from i6_experiments.users.raissi.costum.returnn.rasr_returnn_bw import ReturnnRasrTrainingBWJob @@ -124,6 +125,7 @@ def __init__( # general modeling approach self.label_info = LabelInfo.default_ls() + self.frame_rate_reduction_ratio_info = FrameRateReductionRatioinfo.default() self.lexicon_args = get_lexicon_args(norm_pronunciation=False) self.tdp_values = get_tdp_values() diff --git a/users/raissi/setups/common/TF_factored_hybrid_system.py b/users/raissi/setups/common/TF_factored_hybrid_system.py index 6888b0e49..e94f1bff3 100644 --- a/users/raissi/setups/common/TF_factored_hybrid_system.py +++ b/users/raissi/setups/common/TF_factored_hybrid_system.py @@ -224,19 +224,29 @@ def get_blstm_network(self, **kwargs): return network - def get_conformer_network(self, chunking: str, conf_model_dim: int, label_smoothing: float = 0.0, **kwargs): + def get_conformer_network( + self, + chunking: str, + conf_model_dim: int, + frame_rate_reduction_ratio_info: net_helpers.FrameRateReductionRatioinfo, + label_smoothing: float = 0.0, + **kwargs, + ): # this only includes auxilaury losses network_builder = encoder_archs.conformer.get_best_conformer_network( size=conf_model_dim, num_classes=self.label_info.get_n_of_dense_classes(), num_input_feature=self.initial_nn_args["num_input"], + time_tag_name=frame_rate_reduction_ratio_info.time_tag_name, chunking=chunking, label_smoothing=label_smoothing, - additional_args=kwargs, + additional_args=kwargs, # subsampling factor is passed here ) network = network_builder.network if self.training_criterion != TrainingCriterion.fullsum: - network = net_helpers.augment.augment_net_with_label_pops(network, label_info=self.label_info) + network = net_helpers.augment.augment_net_with_label_pops( + network, label_info=self.label_info, frame_rate_reduction_ratio_info=frame_rate_reduction_ratio_info + ) return network # -------------------- Decoding -------------------- @@ -246,6 +256,7 @@ def _set_diphone_joint_state_tying(self): for crp_k in self.crp_names.keys(): if "train" not in crp_k: self._update_crp_am_setting_for_decoding(self.crp_names[crp_k]) + def _compute_returnn_rasr_priors( self, key: str, @@ -420,7 +431,6 @@ def set_single_prior_returnn_rasr( share=data_share, ) - job.add_alias(f"priors/{name}/single_prior") if context_type == PhoneticContext.monophone: p_info = PriorInfo( diff --git a/users/raissi/setups/common/encoder/blstm/layer.py b/users/raissi/setups/common/encoder/blstm/layer.py index 66879cc8a..6fbddde10 100644 --- a/users/raissi/setups/common/encoder/blstm/layer.py +++ b/users/raissi/setups/common/encoder/blstm/layer.py @@ -6,7 +6,7 @@ def blstm_network( dropout: float = 0.1, l2: float = 0.1, specaugment: bool = True, - as_data: bool = False, + spec_aug_as_data: bool = False, transform_func_name: str = "transform", ): num_layers = len(layers) @@ -15,8 +15,8 @@ def blstm_network( result = {} if specaugment: - if as_data: - eval_str = f"self.network.get_config().typed_value('{transform_func_name}')(source(0, as_data={as_data}), network=self.network)" + if spec_aug_as_data: + eval_str = f"self.network.get_config().typed_value('{transform_func_name}')(source(0, as_data={spec_aug_as_data}), network=self.network)" else: eval_str = ( f"self.network.get_config().typed_value('{transform_func_name}')(source(0), network=self.network)" diff --git a/users/raissi/setups/common/encoder/conformer/best_setup.py b/users/raissi/setups/common/encoder/conformer/best_setup.py index 1d808a366..787faec54 100644 --- a/users/raissi/setups/common/encoder/conformer/best_setup.py +++ b/users/raissi/setups/common/encoder/conformer/best_setup.py @@ -8,6 +8,7 @@ get_network_args, ) +from i6_experiments.users.raissi.setups.common.encoder.conformer.layers import DEFAULT_INIT from i6_experiments.users.raissi.setups.common.encoder.conformer.transformer_network import attention_for_hybrid INT_LOSS_LAYER = 6 @@ -34,6 +35,9 @@ def get_best_model_config( label_smoothing: Optional[float] = None, target: str = "classes", time_tag_name: Optional[str] = None, + upsample_by_transposed_conv: bool = True, + feature_stacking_size: int = 3, + weights_init: str = DEFAULT_INIT, additional_args: Optional[dict] = None, ) -> attention_for_hybrid: if int_loss_at_layer is None: @@ -59,31 +63,30 @@ def get_best_model_config( 32, 0.1, 0.0, - **{ - "relative_pe": True, - "clipping": clipping, - "layer_norm_instead_of_batch_norm": True, - }, + clipping=clipping, + layer_norm_instead_of_batch_norm=True, + relative_pe=True, + initialization=weights_init, ) - loss6_down_up_3_two_vggs_args = { + args = { "add_blstm_block": False, "add_conv_block": True, "loss_layer_idx": int_loss_at_layer, "loss_scale": int_loss_scale, "feature_stacking": True, - "feature_stacking_window": [2, 0], - "feature_stacking_stride": 3, - "transposed_conv": True, + "feature_stacking_window": [feature_stacking_size - 1, 0], + "feature_stacking_stride": feature_stacking_size, + "transposed_conv": upsample_by_transposed_conv, "transposed_conv_args": { "time_tag_name": time_tag_name, }, } if additional_args is not None: - loss6_down_up_3_two_vggs_args.update(**additional_args) + args.update(**additional_args) - pe400_conformer_down_up_3_loss6_args = get_network_args( + configured_args = get_network_args( num_enc_layers=12, type="conformer", enc_args=pe400_enc_args, @@ -91,10 +94,10 @@ def get_best_model_config( num_classes=num_classes, num_input_feature=num_input_feature, label_smoothing=label_smoothing, - **loss6_down_up_3_two_vggs_args, + **args, ) - pe400_conformer_layer_norm_down_up_3_loss6 = attention_for_hybrid(**pe400_conformer_down_up_3_loss6_args) - pe400_conformer_layer_norm_down_up_3_loss6.get_network() + conformer = attention_for_hybrid(**configured_args) + conformer.get_network() - return pe400_conformer_layer_norm_down_up_3_loss6 + return conformer diff --git a/users/raissi/setups/common/encoder/conformer/conformer.py b/users/raissi/setups/common/encoder/conformer/conformer.py index 7f0c4c49a..d1abe4459 100644 --- a/users/raissi/setups/common/encoder/conformer/conformer.py +++ b/users/raissi/setups/common/encoder/conformer/conformer.py @@ -1,11 +1,12 @@ __all__ = ["get_best_conformer_network"] -from typing import Optional, Union +from typing import Any, Optional, Union from i6_experiments.users.raissi.setups.common.encoder.conformer.best_setup import get_best_model_config, Size -from i6_experiments.users.raissi.setups.common.helpers.network.augment import Network +from i6_experiments.users.raissi.setups.common.encoder.conformer.transformer_network import attention_for_hybrid from i6_experiments.users.raissi.setups.common.helpers.train import returnn_time_tag +from i6_experiments.users.raissi.setups.common.encoder.conformer.layers import DEFAULT_INIT def get_best_conformer_network( @@ -20,8 +21,11 @@ def get_best_conformer_network( label_smoothing: float = 0.0, leave_cart_output: bool = False, target: str = "classes", - additional_args: dict = None, -) -> Network: + upsample_by_transposed_conv: bool = True, + feature_stacking_size: int = 3, + weights_init: str = DEFAULT_INIT, + additional_args: Optional[Any] = None, +) -> attention_for_hybrid: if time_tag_name is None: _, time_tag_name = returnn_time_tag.get_shared_time_tag() conformer_net = get_best_model_config( @@ -34,14 +38,16 @@ def get_best_conformer_network( label_smoothing=label_smoothing, target=target, time_tag_name=time_tag_name, + upsample_by_transposed_conv=upsample_by_transposed_conv, + feature_stacking_size=feature_stacking_size, + weights_init=weights_init, additional_args=additional_args, ) if not leave_cart_output: - conformer_net.network.pop("output", None) - conformer_net.network["encoder-output"] = { - "class": "copy", - "from": "length_masked", - } + cart_out = conformer_net.network.pop("output") + last_layer = cart_out["from"][0] + + conformer_net.network["encoder-output"] = {"class": "copy", "from": last_layer} return conformer_net diff --git a/users/raissi/setups/common/encoder/conformer/transformer_network.py b/users/raissi/setups/common/encoder/conformer/transformer_network.py index 26b364b7e..947bb74ef 100644 --- a/users/raissi/setups/common/encoder/conformer/transformer_network.py +++ b/users/raissi/setups/common/encoder/conformer/transformer_network.py @@ -28,7 +28,7 @@ def __init__( focal_loss_factor=2.0, softmax_dropout=0.0, use_spec_augment=True, - spec_aug_as_data=True, + spec_aug_as_data=False, use_pos_encoding=False, add_to_input=True, src_embed_args=None, @@ -64,7 +64,6 @@ def __init__( assert type in ["transformer", "conformer"] # TODO: attention window left and right - if type == "transformer": enc_args.pop("kernel_size", None) enc_args.pop("conv_post_dropout", None) @@ -197,8 +196,11 @@ def __init__( if (feature_stacking and feature_stacking_stride >= 2) or ( reduction_factor and reduction_factor[0] * reduction_factor[1] >= 2 ): - assert alignment_reduction or transposed_conv or frame_repetition - assert (alignment_reduction + transposed_conv + frame_repetition) == 1 + # Old asserts from when everything was upsampled + # + # assert alignment_reduction or transposed_conv or frame_repetition + # assert (alignment_reduction + transposed_conv + frame_repetition) == 1 + pass else: alignment_reduction = transposed_conv = frame_repetition = False @@ -488,12 +490,12 @@ def _conv_block(self, inp=None, prefix=""): if self.conv_args: for name in [ - "conv0_0", - "conv0_1", - "conv0p", - "conv1_0", - "conv1_1", - "conv1p", + f"{prefix}conv0_0", + f"{prefix}conv0_1", + f"{prefix}conv0p", + f"{prefix}conv1_0", + f"{prefix}conv1_1", + f"{prefix}conv1p", ]: if self.conv_args.get(name, None): self.network[name].update(self.conv_args.pop(name)) diff --git a/users/raissi/setups/common/helpers/network/__init__.py b/users/raissi/setups/common/helpers/network/__init__.py index 45841a51f..5e3cfd3ad 100644 --- a/users/raissi/setups/common/helpers/network/__init__.py +++ b/users/raissi/setups/common/helpers/network/__init__.py @@ -2,3 +2,4 @@ from .aux_loss import * from .extern_data import * from .diphone_joint_output import * +from .frame_rate import * diff --git a/users/raissi/setups/common/helpers/network/augment.py b/users/raissi/setups/common/helpers/network/augment.py index db6cfd1d7..0b66722cd 100644 --- a/users/raissi/setups/common/helpers/network/augment.py +++ b/users/raissi/setups/common/helpers/network/augment.py @@ -11,8 +11,11 @@ LabelInfo, PhonemeStateClasses, PhoneticContext, + RasrStateTying, ) +from i6_experiments.users.raissi.setups.common.helpers.network.frame_rate import FrameRateReductionRatioinfo + DEFAULT_INIT = "variance_scaling_initializer(mode='fan_in', distribution='uniform', scale=0.78)" Layer = Dict[str, Any] @@ -71,15 +74,16 @@ def pop_phoneme_state_classes( network: Network, labeling_input: str, remaining_classes: int, + prefix: str = "", ) -> Tuple[Network, str, int]: if label_info.phoneme_state_classes == PhonemeStateClasses.boundary: - class_layer_name = "boundaryClass" - labeling_output = "popBoundary" + class_layer_name = f"{prefix}boundaryClass" + labeling_output = f"{prefix}popBoundary" # continues below elif label_info.phoneme_state_classes == PhonemeStateClasses.word_end: - class_layer_name = "wordEndClass" - labeling_output = "popWordEnd" + class_layer_name = f"{prefix}wordEndClass" + labeling_output = f"{prefix}popWordEnd" # continues below elif label_info.phoneme_state_classes == PhonemeStateClasses.none: @@ -108,33 +112,60 @@ def pop_phoneme_state_classes( return network, labeling_output, rem_dim -def augment_net_with_label_pops(network: Network, label_info: LabelInfo) -> Network: - labeling_input = "data:classes" +def augment_net_with_label_pops( + network: Network, + label_info: LabelInfo, + frame_rate_reduction_ratio_info: FrameRateReductionRatioinfo, + prefix: str = "", + labeling_input: str = "data:classes", +) -> Network: + assert label_info.state_tying in [RasrStateTying.diphone, RasrStateTying.triphone] + remaining_label_dim = label_info.get_n_of_dense_classes() network = copy.deepcopy(network) - network["futureLabel"] = { - "class": "eval", - "from": labeling_input, - "eval": f"tf.math.floormod(source(0), {label_info.n_contexts})", - "register_as_extern_data": "futureLabel", - "out_type": {"dim": label_info.n_contexts, "dtype": "int32", "sparse": True}, - } - remaining_label_dim //= label_info.n_contexts - network["popFutureLabel"] = { - "class": "eval", - "from": labeling_input, - "eval": f"tf.math.floordiv(source(0), {label_info.n_contexts})", - "out_type": {"dim": remaining_label_dim, "dtype": "int32", "sparse": True}, - } - labeling_input = "popFutureLabel" + if frame_rate_reduction_ratio_info.factor > 1: + # This layer sets the time step ratio between the input and the output of the NN. + + frr_factors = ( + [frame_rate_reduction_ratio_info.factor] + if isinstance(frame_rate_reduction_ratio_info.factor, int) + else frame_rate_reduction_ratio_info.factor + ) + t_tag = f"{frame_rate_reduction_ratio_info.time_tag_name}" + for factor in frr_factors: + t_tag += f".ceildiv_right({factor})" + + network[f"{prefix}classes_"] = { + "class": "reinterpret_data", + "set_dim_tags": {"T": returnn.CodeWrapper(t_tag)}, + "from": labeling_input, + } + labeling_input = f"{prefix}classes_" + + if label_info.state_tying == RasrStateTying.triphone: + network[f"{prefix}futureLabel"] = { + "class": "eval", + "from": labeling_input, + "eval": f"tf.math.floormod(source(0), {label_info.n_contexts})", + "register_as_extern_data": f"{prefix}futureLabel", + "out_type": {"dim": label_info.n_contexts, "dtype": "int32", "sparse": True}, + } + remaining_label_dim //= label_info.n_contexts + network[f"{prefix}popFutureLabel"] = { + "class": "eval", + "from": labeling_input, + "eval": f"tf.math.floordiv(source(0), {label_info.n_contexts})", + "out_type": {"dim": remaining_label_dim, "dtype": "int32", "sparse": True}, + } + labeling_input = f"{prefix}popFutureLabel" - network["pastLabel"] = { + network[f"{prefix}pastLabel"] = { "class": "eval", "from": labeling_input, "eval": f"tf.math.floormod(source(0), {label_info.n_contexts})", - "register_as_extern_data": "pastLabel", + "register_as_extern_data": f"{prefix}pastLabel", "out_type": {"dim": label_info.n_contexts, "dtype": "int32", "sparse": True}, } @@ -142,20 +173,24 @@ def augment_net_with_label_pops(network: Network, label_info: LabelInfo) -> Netw assert remaining_label_dim == label_info.get_n_state_classes() # popPastLabel in disguise, the label order makes it so that this is directly the center state - network["centerState"] = { + network[f"{prefix}centerState"] = { "class": "eval", "from": labeling_input, "eval": f"tf.math.floordiv(source(0), {label_info.n_contexts})", - "register_as_extern_data": "centerState", + "register_as_extern_data": f"{prefix}centerState", "out_type": {"dim": remaining_label_dim, "dtype": "int32", "sparse": True}, } - labeling_input = "centerState" + labeling_input = f"{prefix}centerState" network, labeling_input, remaining_label_dim = pop_phoneme_state_classes( - label_info, network, labeling_input, remaining_label_dim + label_info, + network, + labeling_input, + remaining_label_dim, + prefix=prefix, ) - network["stateId"] = { + network[f"{prefix}stateId"] = { "class": "eval", "from": labeling_input, "eval": f"tf.math.floormod(source(0), {label_info.n_states_per_phone})", @@ -169,7 +204,7 @@ def augment_net_with_label_pops(network: Network, label_info: LabelInfo) -> Netw remaining_label_dim //= label_info.n_states_per_phone assert remaining_label_dim == label_info.n_contexts - network["centerPhoneme"] = { + network[f"{prefix}centerPhoneme"] = { "class": "eval", "from": labeling_input, "eval": f"tf.math.floordiv(source(0), {label_info.n_states_per_phone})", @@ -191,12 +226,14 @@ def augment_net_with_monophone_outputs( add_mlps=True, use_multi_task=True, final_ctx_type: Optional[PhoneticContext] = None, + focal_loss_factor=2.0, label_smoothing=0.0, l2=None, encoder_output_layer: str = "encoder-output", prefix: str = "", loss_scale=1.0, shared_delta_encoder=False, + weights_init: str = DEFAULT_INIT, ) -> Network: assert ( encoder_output_layer in shared_network @@ -206,7 +243,9 @@ def augment_net_with_monophone_outputs( network = copy.copy(shared_network) encoder_out_len = encoder_output_len - loss_opts = {"focal_loss_factor": 2.0} + loss_opts = {} + if focal_loss_factor > 0.0: + loss_opts["focal_loss_factor"] = focal_loss_factor if label_smoothing > 0.0: loss_opts["label_smoothing"] = label_smoothing @@ -220,12 +259,12 @@ def augment_net_with_monophone_outputs( prefix=prefix, source_layer=encoder_output_layer, l2=l2, + init=weights_init, ) network[f"{prefix}center-output"] = { "class": "softmax", "from": tri_mlp, - "n_out": label_info.get_n_state_classes(), "target": "centerState", "loss": "ce", "loss_opts": copy.copy(loss_opts), @@ -239,6 +278,7 @@ def augment_net_with_monophone_outputs( prefix=prefix, source_layer=encoder_output_layer, l2=l2, + init=weights_init, ) network[f"{prefix}right-output"] = { "class": "softmax", @@ -267,13 +307,13 @@ def augment_net_with_monophone_outputs( prefix=prefix, source_layer=encoder_output_layer, l2=l2, + init=weights_init, ) network[f"{prefix}center-output"] = { "class": "softmax", "from": di_mlp, "target": "centerState", - "n_out": label_info.get_n_state_classes(), "loss": "ce", "loss_opts": copy.copy(loss_opts), } @@ -287,6 +327,7 @@ def augment_net_with_monophone_outputs( prefix=prefix, source_layer=encoder_output_layer, l2=l2, + init=weights_init, ) tri_mlp = add_mlp( network, @@ -295,6 +336,7 @@ def augment_net_with_monophone_outputs( prefix=prefix, source_layer=encoder_output_layer, l2=l2, + init=weights_init, ) network[f"{prefix}left-output"] = { @@ -326,6 +368,7 @@ def augment_net_with_monophone_outputs( prefix=prefix, source_layer=encoder_output_layer, l2=l2, + init=weights_init, ) di_mlp = add_mlp( network, @@ -334,6 +377,7 @@ def augment_net_with_monophone_outputs( prefix=prefix, source_layer=encoder_output_layer, l2=l2, + init=weights_init, ) tri_mlp = add_mlp( network, @@ -342,6 +386,7 @@ def augment_net_with_monophone_outputs( prefix=prefix, source_layer=encoder_output_layer, l2=l2, + init=weights_init, ) network[f"{prefix}center-output"] = { @@ -382,10 +427,24 @@ def augment_net_with_monophone_outputs( l2=l2, source_layer=encoder_output_layer, ) - di_mlp = add_mlp(network, "diphone", di_out, source_layer=delta_blstm_n, l2=l2) + di_mlp = add_mlp( + network, + "diphone", + di_out, + source_layer=delta_blstm_n, + l2=l2, + init=weights_init, + ) else: add_delta_blstm_(network, name=delta_blstm_n, l2=l2, prefix=prefix) - di_mlp = add_mlp(network, "diphone", di_out, l2=l2, prefix=prefix) + di_mlp = add_mlp( + network, + "diphone", + di_out, + l2=l2, + prefix=prefix, + init=weights_init, + ) network[f"{prefix}center-output"] = { "class": "softmax", @@ -404,6 +463,7 @@ def augment_net_with_monophone_outputs( prefix=prefix, source_layer=encoder_output_layer, l2=l2, + init=weights_init, ) if shared_delta_encoder: @@ -414,6 +474,7 @@ def augment_net_with_monophone_outputs( prefix=prefix, source_layer=delta_blstm_n, l2=l2, + init=weights_init, ) else: tri_mlp = add_mlp( @@ -423,6 +484,7 @@ def augment_net_with_monophone_outputs( prefix=prefix, source_layer=delta_blstm_n, l2=l2, + init=weights_init, ) network[f"{prefix}left-output"] = { @@ -448,7 +510,6 @@ def augment_net_with_monophone_outputs( "class": "softmax", "from": encoder_output_layer, "target": "centerState", - "n_out": label_info.get_n_state_classes(), "loss": "ce", "loss_opts": copy.copy(loss_opts), } @@ -492,6 +553,7 @@ def augment_net_with_diphone_outputs( st_emb_size=256, encoder_output_layer: str = "encoder-output", prefix: str = "", + weights_init: str = DEFAULT_INIT, ) -> Network: assert ( encoder_output_layer in shared_network @@ -503,14 +565,11 @@ def augment_net_with_diphone_outputs( if use_multi_task: network["currentState"] = get_embedding_layer(source="centerState", dim=st_emb_size, l2=l2) - network[f"{prefix}linear1-triphone"]["from"] = [ - encoder_output_layer, - "currentState", - ] + network[f"{prefix}linear1-triphone"]["from"] = [encoder_output_layer, "currentState"] else: loss_opts = copy.deepcopy(network[f"{prefix}center-output"]["loss_opts"]) loss_opts["label_smoothing"] = label_smoothing - left_ctx_mlp = add_mlp(network, "leftContext", encoder_output_len, l2=l2, prefix=prefix) + left_ctx_mlp = add_mlp(network, "leftContext", encoder_output_len, l2=l2, prefix=prefix, init=weights_init) network[f"{prefix}left-output"] = { "class": "softmax", "from": left_ctx_mlp, @@ -570,7 +629,9 @@ def augment_net_with_triphone_outputs( return network -def remove_label_pops_and_losses(network: Network, except_layers: Optional[Iterable[str]] = None) -> Network: +def remove_label_pops_and_losses( + network: Network, except_layers: Optional[Iterable[str]] = None +) -> Network: network = copy.copy(network) layers_to_pop = { @@ -609,7 +670,6 @@ def remove_label_pops_and_losses_from_returnn_config( return cfg - def add_fast_bw_layer( crp: rasr.CommonRasrParameters, returnn_config: returnn.ReturnnConfig, diff --git a/users/raissi/setups/common/helpers/network/aux_loss.py b/users/raissi/setups/common/helpers/network/aux_loss.py index 9ec04473c..cd2b7f1b9 100644 --- a/users/raissi/setups/common/helpers/network/aux_loss.py +++ b/users/raissi/setups/common/helpers/network/aux_loss.py @@ -25,30 +25,34 @@ def add_intermediate_loss( scale: float = 0.5, center_state_only: bool = False, final_ctx_type: PhoneticContext = PhoneticContext.triphone_forward, - label_smoothing: float = 0.2, + focal_loss_factor: float = 2.0, + label_smoothing: float = 0.0, l2: float = 0.0, + upsampling: bool = True, ) -> Network: - assert ( - f"aux_{at_layer}_ff1" in network - ), "network needs to be built w/ CART intermediate loss to add FH intermediate loss (upsampling)" - network = copy.deepcopy(network) + prefix = f"aux_{at_layer:03d}_" - network.pop(f"aux_{at_layer}_ff1", None) + old_ff1_layer = network.pop(f"aux_{at_layer}_ff1", None) network.pop(f"aux_{at_layer}_ff2", None) network.pop(f"aux_{at_layer}_output_prob", None) - aux_length = network.pop(f"aux_{at_layer}_length_masked") - input_layer = f"aux_{at_layer:03d}_length_masked" - prefix = f"aux_{at_layer:03d}_" + if upsampling: + assert ( + old_ff1_layer is not None + ), "network needs to be built w/ CART intermediate loss to add FH intermediate loss (upsampling)" - network[input_layer] = { - "class": "slice_nd", - "from": aux_length["from"], - "start": 0, - "size": returnn.CodeWrapper(time_tag_name), - "axis": "T", - } + aux_length = network.pop(f"aux_{at_layer}_length_masked") + input_layer = f"aux_{at_layer:03d}_length_masked" + network[input_layer] = { + "class": "slice_nd", + "from": aux_length["from"], + "start": 0, + "size": returnn.CodeWrapper(time_tag_name), + "axis": "T", + } + else: + input_layer = f"enc_{at_layer:03d}" network = augment_net_with_monophone_outputs( network, @@ -57,6 +61,7 @@ def add_intermediate_loss( final_ctx_type=final_ctx_type, encoder_output_len=encoder_output_len, encoder_output_layer=input_layer, + focal_loss_factor=focal_loss_factor, label_smoothing=label_smoothing, l2=l2, use_multi_task=True, diff --git a/users/raissi/setups/common/helpers/network/frame_rate.py b/users/raissi/setups/common/helpers/network/frame_rate.py new file mode 100644 index 000000000..49e27382c --- /dev/null +++ b/users/raissi/setups/common/helpers/network/frame_rate.py @@ -0,0 +1,16 @@ +from dataclasses import dataclass +from typing import List, Union + +from i6_experiments.users.raissi.setups.common.helpers.train.returnn_time_tag import get_default_time_tag_str + +@dataclass(frozen=True, eq=True) +class FrameRateReductionRatioinfo: + factor: Union[int, List[int]] + time_tag_name: str + + @classmethod + def default(cls) -> "FrameRateReductionRatioinfo": + return FrameRateReductionRatioinfo( + factor=1, + time_tag_name=get_default_time_tag_str() + ) \ No newline at end of file diff --git a/users/raissi/setups/common/helpers/train/network_params.py b/users/raissi/setups/common/helpers/train/network_params.py index 9912ce8cb..02a3ea573 100644 --- a/users/raissi/setups/common/helpers/train/network_params.py +++ b/users/raissi/setups/common/helpers/train/network_params.py @@ -26,11 +26,11 @@ class GeneralNetworkParams: use_multi_task: Optional[bool] = True add_mlps: Optional[bool] = True specaug_args: Optional[dict] = None - subsampling_factor: Optional[int] = 1 + frame_rate_reduction_ratio_factor: Optional[int] = 1 def __post_init__(self): - if self.subsampling_factor > 1: - self.chunking = train_helpers.chunking_with_nfactor(self.chunking, self.subsampling_factor) + if self.frame_rate_reduction_ratio_factor > 1: + self.chunking = train_helpers.chunking_with_nfactor(self.chunking, self.frame_rate_reduction_ratio_factor) # SpecAug params @@ -39,5 +39,4 @@ def __post_init__(self): # no chunking for full-sum default_blstm_fullsum = GeneralNetworkParams(l2=1e-3, use_multi_task=False, add_mlps=False) - default_conformer_viterbi = GeneralNetworkParams(chunking="400:200", l2=1e-6, specaug_args=asdict(default_sa_args)) diff --git a/users/raissi/setups/common/helpers/train/returnn_time_tag.py b/users/raissi/setups/common/helpers/train/returnn_time_tag.py index 2eaa4ee04..ed7ae03a3 100644 --- a/users/raissi/setups/common/helpers/train/returnn_time_tag.py +++ b/users/raissi/setups/common/helpers/train/returnn_time_tag.py @@ -2,13 +2,16 @@ from textwrap import dedent -import typing +from typing import Tuple from i6_core import returnn -def get_shared_time_tag() -> typing.Tuple[str, str]: - var_name = "__time_tag__" +def get_default_time_tag_str() -> str: + return "__time_tag__" + +def get_shared_time_tag() -> Tuple[str, str]: + var_name = get_default_time_tag_str() code = dedent( f""" from returnn.tf.util.data import Dim @@ -24,7 +27,7 @@ def get_context_dim_tag_prolog( context_type: str, spatial_dim_variable_name: str, feature_dim_variable_name: str, -) -> typing.Tuple[str, returnn.CodeWrapper, returnn.CodeWrapper]: +) -> Tuple[str, returnn.CodeWrapper, returnn.CodeWrapper]: code = dedent( f""" from returnn.tf.util.data import FeatureDim, SpatialDim