Skip to content

Commit

Permalink
prior and decoding for joint to be tested
Browse files Browse the repository at this point in the history
  • Loading branch information
Marvin84 committed Nov 22, 2023
1 parent 537e0a3 commit 23c65cc
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 51 deletions.
26 changes: 22 additions & 4 deletions users/raissi/setups/common/TF_factored_hybrid_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,15 +381,17 @@ def _compute_returnn_rasr_priors_via_hdf(

return prior_job

def set_mono_priors_returnn_rasr(
def set_single_prior_returnn_rasr(
self,
key: str,
epoch: int,
train_corpus_key: str,
dev_corpus_key: str,
context_type: PhoneticContext = PhoneticContext.monophone,
returnn_config: Optional[returnn.ReturnnConfig] = None,
output_layer_name: str = "output",
data_share: float = 0.1,
smoothen: bool = False,
data_share: float = 0.3,
):
self.set_graph_for_experiment(key)

Expand All @@ -411,10 +413,26 @@ def set_mono_priors_returnn_rasr(
share=data_share,
)


job.add_alias(f"priors/{name}/c")
tk.register_output(f"priors/{name}/center-state.xml", job.out_prior_xml_file)

self.experiments[key]["priors"] = [job.out_prior_xml_file]
if context_type == PhoneticContext.monophone:
p_info = PriorInfo(
center_state_prior=PriorConfig(file=job.out_prior_xml_file, scale=0.0),
)
tk.register_output(f"priors/{name}/center-state.xml", p_info.center_state_prior.file)
elif context_type == PhoneticContext.joint_diphone:
p_info = PriorInfo(
diphone_prior=PriorConfig(file=job.out_prior_xml_file, scale=0.0),
)
tk.register_output(f"priors/{name}/joint_diphone.xml", p_info.diphone_prior.file)
else:
raise NotImplementedError("Unknown PhoneticContext, i.e. context_type")

p_info = smoothen_priors(p_info) if smoothen else p_info
self.experiments[key]["priors"] = p_info



def set_diphone_priors_returnn_rasr(
self,
Expand Down
76 changes: 43 additions & 33 deletions users/raissi/setups/common/data/factored_label.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,39 +42,6 @@ def use_word_end(self) -> bool:
return self == PhonemeStateClasses.word_end


@dataclass(eq=True, frozen=True)
class LabelInfo:
    """Immutable bundle of label-set sizes used to derive class counts."""

    # Base count of contexts; see get_n_of_dense_classes for how it is adjusted.
    n_contexts: int
    # Multiplier applied per phone in both class-count computations.
    n_states_per_phone: int
    # Provides ``factor()``, which multiplies both class counts.
    phoneme_state_classes: PhonemeStateClasses
    # NOTE(review): presumably embedding sizes for phonemes and states — confirm with the model code.
    ph_emb_size: int
    st_emb_size: int
    state_tying: RasrStateTying
    # When False, one extra context is counted in get_n_of_dense_classes.
    add_unknown_phoneme: bool = True
    sil_id: typing.Optional[int] = None

    def get_n_of_dense_classes(self) -> int:
        """Return ``n_states_per_phone * contexts**3 * factor()``.

        ``contexts`` is ``n_contexts``, incremented by one when
        ``add_unknown_phoneme`` is False.
        """
        context_count = self.n_contexts if self.add_unknown_phoneme else self.n_contexts + 1
        return self.n_states_per_phone * (context_count**3) * self.phoneme_state_classes.factor()

    def get_n_state_classes(self) -> int:
        """Return ``n_states_per_phone * n_contexts * factor()``."""
        states_times_contexts = self.n_states_per_phone * self.n_contexts
        return states_times_contexts * self.phoneme_state_classes.factor()

    @classmethod
    def default_ls(cls) -> "LabelInfo":
        """Return the default label info: 42 contexts, 3 states per phone, word-end classes, triphone tying."""
        return LabelInfo(
            state_tying=RasrStateTying.triphone,
            phoneme_state_classes=PhonemeStateClasses.word_end,
            add_unknown_phoneme=True,
            st_emb_size=128,
            ph_emb_size=32,
            n_states_per_phone=3,
            n_contexts=42,
        )


class PhoneticContext(Enum):
"""
These are the implemented models. The string value is the one used in the feature scorer of rasr, except monophone
Expand Down Expand Up @@ -126,3 +93,46 @@ def is_triphone(self):
or self == PhoneticContext.triphone_symmetric
or self == PhoneticContext.tri_state_transition
)

@dataclass(eq=True, frozen=True)
class LabelInfo:
    """Immutable bundle of label-set sizes used to derive class counts."""

    # Base count of contexts; see get_n_of_dense_classes for how it is adjusted.
    n_contexts: int
    # Multiplier applied per phone in both class-count computations.
    n_states_per_phone: int
    # Provides ``factor()``, which multiplies both class counts.
    phoneme_state_classes: PhonemeStateClasses
    # NOTE(review): presumably embedding sizes for phonemes and states — confirm with the model code.
    ph_emb_size: int
    st_emb_size: int
    # Selects the exponent (1, 2 or 3) used in get_n_of_dense_classes.
    state_tying: RasrStateTying

    # When False, one extra context is counted in get_n_of_dense_classes.
    add_unknown_phoneme: bool = True
    sil_id: typing.Optional[int] = None

    def get_n_of_dense_classes(self) -> int:
        """Return ``n_states_per_phone * contexts**exp * factor()``.

        ``exp`` is 1, 2 or 3 for monophone, diphone or triphone state tying.
        ``contexts`` is ``n_contexts``, incremented by one when
        ``add_unknown_phoneme`` is False.

        Raises:
            AssertionError: if ``state_tying`` is none of the three supported values.
        """
        if self.state_tying == RasrStateTying.monophone:
            exp = 1
        elif self.state_tying == RasrStateTying.diphone:
            exp = 2
        elif self.state_tying == RasrStateTying.triphone:
            exp = 3
        else:
            # Raised explicitly: an ``assert`` is stripped under ``python -O``,
            # which would leave ``exp`` unbound and fail later with an
            # unrelated UnboundLocalError.
            raise AssertionError("cannot compute number of CART classes")

        n_contexts = self.n_contexts
        if not self.add_unknown_phoneme:
            n_contexts += 1
        return self.n_states_per_phone * (n_contexts**exp) * self.phoneme_state_classes.factor()

    def get_n_state_classes(self) -> int:
        """Return ``n_states_per_phone * n_contexts * factor()``."""
        return self.n_states_per_phone * self.n_contexts * self.phoneme_state_classes.factor()

    @classmethod
    def default_ls(cls) -> "LabelInfo":
        """Return the default label info: 42 contexts, 3 states per phone, silence id 40, triphone tying."""
        return LabelInfo(
            n_contexts=42,
            n_states_per_phone=3,
            ph_emb_size=32,
            st_emb_size=128,
            sil_id=40,
            add_unknown_phoneme=True,
            phoneme_state_classes=PhonemeStateClasses.word_end,
            state_tying=RasrStateTying.triphone,
        )
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,9 @@ def __str__(self):
return self.value

def get_fs_class(self):
    """Return the feature scorer class that corresponds to this member.

    Returns:
        ``FactoredHybridFeatureScorer`` for the ``factored`` member,
        ``rasr.PrecomputedHybridFeatureScorer`` for ``nn_precomputed``.

    Raises:
        ValueError: if this member is neither of the two.
    """
    # Compare by equality. A bare truthiness test such as ``if self.factored:``
    # would be true for every member and make the second branch unreachable.
    if self == self.factored:
        return FactoredHybridFeatureScorer
    elif self == self.nn_precomputed:
        return rasr.PrecomputedHybridFeatureScorer
    else:
        raise ValueError("Unknown type of Feature Scorer")
Expand Down Expand Up @@ -232,10 +232,11 @@ def get_nn_precomputed_feature_scorer(
if isinstance(prior_info, PriorInfo):
check_prior_info(context_type=context_type, prior_info=prior_info)


return feature_scorer_type.get_fs_class()(
prior_mixtures=mixtures,
prior_file=prior_info.joint_diphone.file,
priori_scale=prior_info.joint_diphone.scale,
priori_scale=prior_info.diphone_prior.scale,
prior_file=prior_info.diphone_prior.file,
)


Expand Down Expand Up @@ -356,21 +357,21 @@ def set_tf_fs_flow(self):
raise NotImplementedError

def set_nnprecomputed_tf_fs_flow(self):
    """Set ``self.feature_scorer_flow`` to the flow built by ``get_nnprecomputed_tf_flow``."""
    # NOTE(review): the flow is assigned as-is and is no longer merged into
    # ``self.feature_path`` through ``add_tf_flow_to_base_flow``. Confirm that
    # the precomputed flow already contains the base feature flow.
    self.feature_scorer_flow = self.get_nnprecomputed_tf_flow()

def get_nnprecomputed_tf_flow(self):
    """Build the precomputed-hybrid TF feature flow for the joint-diphone output.

    Passes this object's graph, checkpoint (``self.model_path``), the
    ``out_joint_diphone`` entry of the tensor map and the native-op library
    (``self.library_path``) to ``make_precomputed_hybrid_tf_feature_flow``
    and returns its result.
    """
    # ``native_ops`` is passed exactly once; a repeated keyword is a
    # SyntaxError. No ``tf_fwd_input_name`` is forwarded because no such name
    # exists in this method's scope.
    return make_precomputed_hybrid_tf_feature_flow(
        tf_graph=self.graph,
        tf_checkpoint=self.model_path,
        output_layer_name=self.tensor_map.out_joint_diphone,
        native_ops=self.library_path,
    )

def set_factored_tf_fs_flow(self):
Expand Down
45 changes: 39 additions & 6 deletions users/raissi/setups/common/decoder/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,15 @@ class PriorInfo:
center state.
"""

center_state_prior: PriorConfig
center_state_prior: Optional[PriorConfig] = None
left_context_prior: Optional[PriorConfig] = None
right_context_prior: Optional[PriorConfig] = None
diphone_prior: Optional[PriorConfig] = None


def with_scale(
self,
center: Float,
center: Optional[Float] = None,
left: Optional[Float] = None,
right: Optional[Float] = None,
diphone: Optional[Float] = None,
Expand All @@ -50,17 +51,18 @@ def with_scale(
left/right/diphone scale must be set if the left/right priors are set.
"""

assert self.center_state_prior is None or center is not None
assert self.left_context_prior is None or left is not None
assert self.right_context_prior is None or right is not None
assert self.diphone_prior is None or diphone is not None

center = self.center_state_prior.with_scale(center) if self.center_state_prior is not None else None
left = self.left_context_prior.with_scale(left) if self.left_context_prior is not None else None
right = self.right_context_prior.with_scale(right) if self.right_context_prior is not None else None
diphone = self.diphone_prior.with_scale(diphone) if self.diphone_prior is not None else None
return dataclasses.replace(
self,
center_state_prior=self.center_state_prior.with_scale(center),
center_state_prior=center,
left_context_prior=left,
right_context_prior=right,
diphone_prior=diphone,
Expand Down Expand Up @@ -104,6 +106,16 @@ def from_triphone_job(cls, output_dir: Union[str, tk.Path]) -> "PriorInfo":
left_context_prior=PriorConfig(file=output_dir.join_right("left-context.xml"), scale=0.0),
right_context_prior=PriorConfig(file=output_dir.join_right("right-context.xml"), scale=0.0),
)
@classmethod
def from_diiphone_job(cls, output_dir: Union[str, tk.Path]) -> "PriorInfo":
    """
    Initializes a PriorInfo instance that carries only a diphone prior, with
    scale 0.0, read from ``center-state.xml`` inside the given output directory.

    :param output_dir: directory of a previously run/captured prior job; a plain
        string is wrapped in ``tk.Path`` first.
    :return: a new instance with ``diphone_prior`` set and all other priors left
        at their defaults.
    """
    # NOTE(review): the method name keeps its original spelling ("diiphone") so
    # that existing callers continue to work.

    output_dir = tk.Path(output_dir) if isinstance(output_dir, str) else output_dir
    return cls(
        diphone_prior=PriorConfig(file=output_dir.join_right("center-state.xml"), scale=0.0),
    )


PosteriorScales = TypedDict(
Expand Down Expand Up @@ -160,11 +172,12 @@ def with_lm_scale(self, scale: Float) -> "SearchParameters":

def with_prior_scale(
    self,
    center: Optional[Float] = None,
    left: Optional[Float] = None,
    right: Optional[Float] = None,
    diphone: Optional[Float] = None,
) -> "SearchParameters":
    """Return a copy of these parameters with rescaled priors.

    All four scales are forwarded to ``self.prior_info.with_scale``; the
    resulting prior info replaces ``prior_info`` in the copy. ``self`` is not
    modified.
    """
    # ``diphone`` must be forwarded as well, otherwise a configured diphone
    # prior would never receive its scale.
    rescaled = self.prior_info.with_scale(center=center, left=left, right=right, diphone=diphone)
    return dataclasses.replace(self, prior_info=rescaled)

def with_pron_scale(self, pron_scale: Float) -> "SearchParameters":
    """Return a copy of these parameters with ``pron_scale`` replaced by the given value."""
    updated = dataclasses.replace(self, pron_scale=pron_scale)
    return updated
Expand Down Expand Up @@ -232,6 +245,23 @@ def default_triphone(cls, *, priors: PriorInfo) -> "SearchParameters":
we_pruning_limit=10000,
)

@classmethod
def default_joint_diphone(cls, *, priors: PriorInfo) -> "SearchParameters":
    """Build the default search parameters for joint-diphone decoding.

    The given priors are rescaled with a diphone scale of 0.4; all remaining
    values are fixed defaults.
    """
    scaled_priors = priors.with_scale(diphone=0.4)
    return cls(
        # Beam search and word-end pruning.
        beam=20,
        beam_limit=500_000,
        we_pruning=0.5,
        we_pruning_limit=10000,
        # Scales.
        lm_scale=9.0,
        pron_scale=2.0,
        tdp_scale=0.4,
        prior_info=scaled_priors,
        # Time distortion penalties.
        tdp_speech=(3.0, 0.0, "infinity", 0.0),
        tdp_silence=(0.0, 3.0, "infinity", 20.0),
        tdp_non_word=(0.0, 3.0, "infinity", 20.0),
        non_word_phonemes="[UNKNOWN]",
    )

@classmethod
def default_cart(cls, *, priors: PriorInfo) -> "SearchParameters":
return dataclasses.replace(
Expand All @@ -249,5 +279,8 @@ def default_for_ctx(cls, context: PhoneticContext, priors: PriorInfo) -> "Search
return cls.default_diphone(priors=priors)
elif context == PhoneticContext.triphone_forward:
return cls.default_triphone(priors=priors)
elif context == PhoneticContext.joint_diphone:
return cls.default_joint_diphone(priors=priors)

else:
raise NotImplementedError(f"unimplemented context {context}")

0 comments on commit 23c65cc

Please sign in to comment.