Skip to content

Commit

Permalink
adjusted the joint returnn config
Browse files Browse the repository at this point in the history
  • Loading branch information
Marvin84 committed Nov 20, 2023
1 parent b438b0c commit 3aa7a4a
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 3 deletions.
30 changes: 30 additions & 0 deletions users/raissi/setups/common/BASE_factored_hybrid_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -674,6 +674,13 @@ def prepare_rasr_train_data_with_separate_cv(
return nn_train_data_inputs, nn_cv_data_inputs, nn_devtrain_data_inputs

# -------------------------------------------- Training --------------------------------------------------------
def add_code_to_extra_returnn_code(self, key:str, extra_key: str, extra_dict_key:str, code: str):
# extra_key can be either prolog or epilog
assert extra_dict_key is not None, "set the extra dict key for your additional code"
old_to_update = copy.deepcopy(self.experiments[key]["extra_returnn_code"][extra_key])
old_to_update[extra_dict_key] = code
return old_to_update

def get_config_with_standard_prolog_and_epilog(
self, config: Dict, prolog_additional_str: str = None, epilog_additional_str: str = None
):
Expand Down Expand Up @@ -740,6 +747,29 @@ def set_returnn_config_for_experiment(self, key: str, config_dict: Dict):
self.experiments[key]["extra_returnn_code"]["prolog"] = returnn_config.python_prolog
self.experiments[key]["extra_returnn_code"]["epilog"] = returnn_config.python_epilog

def reset_returnn_config_for_experiment(self, key:str,
config_dict: Dict,
extra_dict_key:str = None,
additional_python_prolog: str = None,
additional_python_epilog: str = None):
if additional_python_prolog is not None:
python_prolog = self.add_code_to_extra_returnn_code(key=key, extra_key="prolog", extra_dict_key=extra_dict_key, code=additional_python_prolog)
else: python_prolog = self.experiments[key]["extra_returnn_code"]["prolog"]

if additional_python_epilog is not None:
python_epilog = self.add_code_to_extra_returnn_code(key=key, extra_key="epilog", extra_dict_key=extra_dict_key, code=additional_python_epilog)
else: python_epilog = self.experiments[key]["extra_returnn_code"]["epilog"]

returnn_config = returnn.ReturnnConfig(
config=config_dict,
hash_full_python_code=True,
python_prolog=python_prolog,
python_epilog=python_epilog,
)
self.experiments[key]["returnn_config"] = returnn_config
self.experiments[key]["extra_returnn_code"]["prolog"] = returnn_config.python_prolog
self.experiments[key]["extra_returnn_code"]["epilog"] = returnn_config.python_epilog

# -------------------------------------------------------------------------

def returnn_training(
Expand Down
35 changes: 33 additions & 2 deletions users/raissi/setups/common/TF_factored_hybrid_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,15 +626,23 @@ def setup_returnn_config_and_graph_for_diphone_joint_decoding(self, key: str=Non
if returnn_config is None:
returnn_config = self.experiments[key]["returnn_config"]
clean_returnn_config = net_helpers.augment.remove_label_pops_and_losses_from_returnn_config(returnn_config)
#used for decoding
context_size = self.label_info.n_contexts
context_time_tag, _, _ = train_helpers.returnn_time_tag.get_context_dim_tag_prolog(spatial_size=context_size,
feature_size=context_size,
spatial_dim_variable_name="__center_state_spatial",
feature_dim_variable_name="__center_state_feature",
context_type='L')

# used for decoding
decoding_returnn_config = net_helpers.diphone_joint_output.augment_to_joint_diphone_softmax(
returnn_config=clean_returnn_config,
label_info=self.label_info,
out_joint_score_layer="output",
log_softmax=True,
)
self.reset_returnn_config_for_experiment(key=key, config_dict=decoding_returnn_config.config, extra_dict_key="context", additional_python_prolog=context_time_tag)
self.set_graph_for_experiment(key)

self.set_graph_for_experiment(key, override_cfg=decoding_returnn_config)
#used for prior estimation
prior_returnn_config = net_helpers.diphone_joint_output.augment_to_joint_diphone_softmax(
returnn_config=clean_returnn_config,
Expand All @@ -645,6 +653,29 @@ def setup_returnn_config_and_graph_for_diphone_joint_decoding(self, key: str=Non

self.experiments[key]["returnn_config"] = prior_returnn_config

def setup_returnn_config_and_graph_for_diphone_joint_prior(self, key: str = None,
returnn_config: returnn.ReturnnConfig = None):

if returnn_config is None:
returnn_config = self.experiments[key]["returnn_config"]
clean_returnn_config = net_helpers.augment.remove_label_pops_and_losses_from_returnn_config(returnn_config)
context_size = self.label_info.n_contexts
context_time_tag = train_helpers.returnn_time_tag.get_context_dim_tag_prolog(spatial_size=context_size,
feature_size=context_size,
spatial_dim_variable_name="__center_state_spatial",
feature_dim_variable_name="__center_state_feature",
context_type='L')
# used for decoding
prior_returnn_config = net_helpers.diphone_joint_output.augment_to_joint_diphone_softmax(
returnn_config=clean_returnn_config,
label_info=self.label_info,
out_joint_score_layer="output",
log_softmax=False,
)
self.reset_returnn_config_for_experiment(key=key, config_dict=prior_returnn_config.config, extra_dict_key="context",
additional_python_prolog=context_time_tag)
self.set_graph_for_experiment(key)


def get_recognizer_and_args(
self,
Expand Down
25 changes: 24 additions & 1 deletion users/raissi/setups/common/helpers/train/returnn_time_tag.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
__all__ = ["get_shared_time_tag"]
__all__ = ["get_shared_time_tag", "get_context_dim_tag_prolog"]


from textwrap import dedent
import typing

from i6_core import returnn


def get_shared_time_tag() -> typing.Tuple[str, str]:
var_name = "__time_tag__"
Expand All @@ -14,3 +16,24 @@ def get_shared_time_tag() -> typing.Tuple[str, str]:
"""
)
return code, var_name


def get_context_dim_tag_prolog(
spatial_size: int,
feature_size: int,
context_type: str,
spatial_dim_variable_name: str,
feature_dim_variable_name: str,
) -> typing.Tuple[str, returnn.CodeWrapper, returnn.CodeWrapper]:
code = dedent(
f"""
from returnn.tf.util.data import FeatureDim, SpatialDim
{spatial_dim_variable_name} = SpatialDim("contexts-{context_type}", {spatial_size})
{feature_dim_variable_name} = FeatureDim("{context_type}", {feature_size})
"""
)
return (
code,
returnn.CodeWrapper(spatial_dim_variable_name),
returnn.CodeWrapper(feature_dim_variable_name),
)

0 comments on commit 3aa7a4a

Please sign in to comment.