From 1dcdf80c587190de9ae189a8c2bfd071e98ca9c9 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Sat, 26 Jan 2019 13:54:56 -0500
Subject: [PATCH] [src] Fixes to grammar-fst code to handle LM-disambig symbols properly (#3000)

thanks: armando.muscariello@gmail.com
---
 src/decoder/grammar-fst.cc | 124 ++++++++++++++++++++++++++++++++++++-
 src/decoder/grammar-fst.h  |  17 ++---
 src/doc/grammar.dox        |   2 +-
 3 files changed, 132 insertions(+), 11 deletions(-)

diff --git a/src/decoder/grammar-fst.cc b/src/decoder/grammar-fst.cc
index 6f95993d078..27d8c9998ea 100644
--- a/src/decoder/grammar-fst.cc
+++ b/src/decoder/grammar-fst.cc
@@ -443,6 +443,98 @@ void GrammarFst::Read(std::istream &is, bool binary) {
 }


/**
   This utility function input-determinizes a specified state s of the FST
   'fst'.  (It input-determinizes while treating epsilon as a real symbol,
   although for the application we expect to use it for, there won't be
   epsilons).

   What this function does is: for any symbol i that appears as the ilabel of
   more than one arc leaving state s of FST 'fst', it creates a new state t
   with input-epsilon transitions leaving it, one for each of those multiple
   arcs leaving state s; it deletes those original arcs leaving state s; and it
   creates a single arc from state s to the newly created state t with the
   ilabel i on it.  It sets the weights as necessary to preserve equivalence,
   and also to ensure that if, prior to this modification, the FST was
   stochastic when cast to the log semiring (see IsStochasticInLog()), it still
   will be.  I.e. when the weights are interpreted as negative log-probs, the
   probability mass on the arc from state s to t equals the sum of the
   probabilities on the original arcs leaving state s (the costs are combined
   with LogAdd).

   This is used as a very cheap solution when preparing FSTs for the grammar
   decoder, to ensure that there is only one entry-state to the sub-FST for each
   phonetic left-context; this keeps the grammar-FST code (i.e. the code that
   stitches them together) simple.  Of course it will tend to introduce
   unnecessary epsilons, and if we were careful we might be able to remove
   some of those, but this wouldn't have a substantial impact on overall
   decoder performance so we don't bother.
 */
static void InputDeterminizeSingleState(StdArc::StateId s,
                                        VectorFst<StdArc> *fst) {
  bool was_input_deterministic = true;
  typedef StdArc Arc;
  typedef Arc::StateId StateId;
  typedef Arc::Label Label;
  typedef Arc::Weight Weight;

  struct InfoForIlabel {
    std::vector<int32> arc_indexes;  // indexes of all arcs with this ilabel
    float tot_cost;  // total cost of all arcs leaving state s for this
                     // ilabel, summed as if they were negative log-probs.
    StateId new_state;  // state-id of new state, if any, that we have created
                        // to remove duplicate symbols with this ilabel.
    InfoForIlabel(): new_state(-1) { }
  };

  std::unordered_map<Label, InfoForIlabel> label_map;

  size_t arc_index = 0;
  for (ArcIterator<VectorFst<Arc> > aiter(*fst, s);
       !aiter.Done(); aiter.Next(), ++arc_index) {
    const Arc &arc = aiter.Value();
    InfoForIlabel &info = label_map[arc.ilabel];
    if (info.arc_indexes.empty()) {
      info.tot_cost = arc.weight.Value();
    } else {
      info.tot_cost = -kaldi::LogAdd(-info.tot_cost, -arc.weight.Value());
      was_input_deterministic = false;
    }
    info.arc_indexes.push_back(arc_index);
  }

  if (was_input_deterministic)
    return;  // Nothing to do.
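
  // Below: rebuild the set of arcs leaving state s.  Each ilabel that appeared
  // on more than one arc gets a single arc from s to a newly created state,
  // carrying the combined (log-added) cost 'tot_cost'; input-epsilon arcs then
  // go from that new state to each original destination, each carrying the
  // original arc's cost minus 'tot_cost'.  The cost along each two-arc path
  // thus equals the original arc's cost, which preserves equivalence and
  // stochasticity in the log semiring.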

  // 'new_arcs' will contain the modified list of arcs
  // leaving state s.
  std::vector<Arc> new_arcs;
  new_arcs.reserve(arc_index);
  arc_index = 0;
  for (ArcIterator<VectorFst<Arc> > aiter(*fst, s);
       !aiter.Done(); aiter.Next(), ++arc_index) {
    const Arc &arc = aiter.Value();
    Label ilabel = arc.ilabel;
    InfoForIlabel &info = label_map[ilabel];
    if (info.arc_indexes.size() == 1) {
      new_arcs.push_back(arc);  // no changes needed
    } else {
      if (info.new_state < 0) {
        info.new_state = fst->AddState();
        // add arc from state 's' to the newly created state.
        new_arcs.push_back(Arc(ilabel, 0, Weight(info.tot_cost),
                               info.new_state));
      }
      // add arc from the new state to the original destination of this arc.
      fst->AddArc(info.new_state, Arc(0, arc.olabel,
                                      Weight(arc.weight.Value() - info.tot_cost),
                                      arc.nextstate));
    }
  }
  fst->DeleteArcs(s);
  for (size_t i = 0; i < new_arcs.size(); i++)
    fst->AddArc(s, new_arcs[i]);
}


// This class contains the implementation of the function
// PrepareForGrammarFst(), which is declared in grammar-fst.h.
class GrammarFstPreparer {
@@ -475,6 +567,12 @@ class GrammarFstPreparer {
       // OK, state s is a special state.
       FixArcsToFinalStates(s);
       MaybeAddFinalProbToState(s);
+      // The following ensures that the start state of a sub-FST has only
+      // a single arc per left-context phone (the graph-building recipe can
+      // end up creating more than one if there were disambiguation symbols,
+      // e.g. for language model backoff).
+      if (s == fst_->Start() && IsEntryState(s))
+        InputDeterminizeSingleState(s, fst_);
     }
   }
 }
@@ -487,7 +585,7 @@
   // Returns true if state 's' has at least one arc coming out of it with a
   // special nonterminal-related ilabel on it (i.e. an ilabel >=
-  // kNontermBigNumber)
+  // kNontermBigNumber), and false otherwise.
   bool IsSpecialState(StateId s) const;

   // This function verifies that state s does not currently have any
@@ -509,6 +607,10 @@
   // modify this state (by adding input-epsilon arcs), and false otherwise.
   bool NeedEpsilons(StateId s) const;

+  // Returns true if state s (which is expected to be the start state, although
+  // we don't check this) has arcs whose nonterminal symbol is #nonterm_begin.
+  bool IsEntryState(StateId s) const;
+
   // Fixes any final-prob-related problems with this state.  The problem we aim
   // to fix is that there may be arcs with nonterminal symbol #nonterm_end which
   // transition from this state to a state with non-unit final prob.  This
@@ -599,6 +701,24 @@ bool GrammarFstPreparer::IsSpecialState(StateId s) const {
   return false;
 }

+bool GrammarFstPreparer::IsEntryState(StateId s) const {
+  int32 big_number = kNontermBigNumber,
+      encoding_multiple = GetEncodingMultiple(nonterm_phones_offset_);
+
+  for (ArcIterator<FST> aiter(*fst_, s); !aiter.Done(); aiter.Next()) {
+    const Arc &arc = aiter.Value();
+    int32 nonterminal = (arc.ilabel - big_number) /
+        encoding_multiple;
+    // We check that at least one arc has a label whose nonterminal part equals
+    // #nonterm_begin; in fact they will all have this value if at least one
+    // does, and this was checked in NeedEpsilons().
+    if (nonterminal == GetPhoneSymbolFor(kNontermBegin))
+      return true;
+  }
+  return false;
+}
+
+
 bool GrammarFstPreparer::NeedEpsilons(StateId s) const {
   // See the documentation for GetCategoryOfArc() for explanation of what these are.
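  // (Reminder on the encoding used here and in IsEntryState() above: ilabels
  // for nonterminal-related arcs have the form
  //   kNontermBigNumber + nonterminal * encoding_multiple + left_context_phone,
  // so dividing (ilabel - kNontermBigNumber) by the encoding multiple recovers
  // the phones.txt symbol of the nonterminal.)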
@@ -647,7 +767,7 @@ bool GrammarFstPreparer::NeedEpsilons(StateId s) const { if (nonterminal == GetPhoneSymbolFor(kNontermBegin) && s != fst_->Start()) { KALDI_ERR << "#nonterm_begin symbol is present but this is not the " - "first arc. Did you do fstdeterminizestar while compiling?"; + "first state. Did you do fstdeterminizestar while compiling?"; } if (nonterminal == GetPhoneSymbolFor(kNontermEnd)) { if (fst_->NumArcs(arc.nextstate) != 0 || diff --git a/src/decoder/grammar-fst.h b/src/decoder/grammar-fst.h index f66933c132d..b82d7b3bc9f 100644 --- a/src/decoder/grammar-fst.h +++ b/src/decoder/grammar-fst.h @@ -229,14 +229,15 @@ class GrammarFst { an arc-index leaving a particular state in an FST (i.e. an index that we could use to Seek() to the matching arc). - @param [in] fst The FST we are looking for state-indexes for - @param [in] entry_state The state in the FST-- must have arcs with - ilabels decodable as (nonterminal_symbol, left_context_phone). - Will either be the start state (if 'nonterminal_symbol' - corresponds to #nonterm_begin), or an internal state - (if 'nonterminal_symbol' corresponds to #nonterm_reenter). - The arc-indexes of those arcs will be the values - we set in 'phone_to_arc' + @param [in] fst The FST that is being entered (or reentered) + @param [in] entry_state The state in 'fst' which is being entered + (or reentered); will be fst.Start() if it's being + entered. It must have arcs with ilabels decodable as + (nonterminal_symbol, left_context_phone). Will either be the + start state (if 'nonterminal_symbol' corresponds to + #nonterm_begin), or an internal state (if 'nonterminal_symbol' + corresponds to #nonterm_reenter). The arc-indexes of those + arcs will be the values we set in 'phone_to_arc' @param [in] nonterminal_symbol The index in phones.txt of the nonterminal symbol we expect to be encoded in the ilabels of the arcs leaving 'entry_state'. Will either correspond diff --git a/src/doc/grammar.dox b/src/doc/grammar.dox index d1c6f51f349..30396041d22 100644 --- a/src/doc/grammar.dox +++ b/src/doc/grammar.dox @@ -352,7 +352,7 @@ Z_S 243 The special symbols in CLG.fst will be as follows. The following special symbols may appear in any CLG graph, top-level or not: - - When any graph invokes a sub-graph, there will be n arc with an ilabel + - When any graph invokes a sub-graph, there will be an arc with an ilabel (\#nonterm:foo, left-context-phone) representing the user-specified nonterminal and the actual left-context, which will be followed by arcs with ilabels of the form (\#nonterm_reenter,