Skip to content

Commit

Permalink
Refactor tree construction in AdvancedTreeSearch to allow multiple TreeBuilder types (#85)
Browse files Browse the repository at this point in the history


Co-authored-by: Eugen Beck <ebeck@apptek.com>
Co-authored-by: Simon Berger <berger@i6.informatik.rwth-aachen.de>
  • Loading branch information
3 people authored Jan 8, 2025
1 parent 8dfd39b commit 6a9ac9c
Show file tree
Hide file tree
Showing 6 changed files with 808 additions and 700 deletions.
21 changes: 11 additions & 10 deletions src/Search/AdvancedTreeSearch/PersistentStateTree.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ struct ConvertTree {
TreeIndex masterTreeIndex;
StateId rootSubTree;
StateId ciRootNode;
std::map<StateTree::Exit, u32> exits; //Maps exits to label-indices @todo Make this a hash_map
std::map<StateTree::Exit, u32> exits; // Maps exits to label-indices @todo Make this a hash_map
std::vector<PersistentStateTree::Exit> exitVector;
Core::HashMap<StateId, StateTree::StateId> statesForNodes;
Core::HashMap<StateTree::StateId, StateId> nodesForStates;
Expand Down Expand Up @@ -73,7 +73,7 @@ struct ConvertTree {
}
}

///Make sure a node is created for every single state, so that also the coarticulated roots are respected
/// Make sure a node is created for every single state, so that also the coarticulated roots are respected

for (std::set<StateTree::StateId>::iterator stateIt = coarticulatedRootStates.begin(); stateIt != coarticulatedRootStates.end(); ++stateIt) {
StateTree::StateId state = *stateIt;
Expand Down Expand Up @@ -121,7 +121,7 @@ struct ConvertTree {
exitIndices.insert(exitEntry->second);
}

//Add connections to the attached outputs/exits
// Add connections to the attached outputs/exits
for (std::set<u32>::iterator it = exitIndices.begin(); it != exitIndices.end(); ++it)
subtrees.addOutputToEdge(subtrees.state(node).successors, *it);
}
Expand Down Expand Up @@ -150,10 +150,10 @@ struct ConvertTree {

subtrees.state(node).stateDesc = state;

//Build successor structure
// Build successor structure
std::pair<StateTree::SuccessorIterator, StateTree::SuccessorIterator> successors = tree->successors(stateId);

StateId current = node; //Just to verify the order
StateId current = node; // Just to verify the order

for (; successors.first != successors.second; ++successors.first) {
std::unordered_map<StateTree::StateId, StateId>::iterator nodeIt = nodesForStates.find(*successors.first);
Expand All @@ -166,14 +166,15 @@ struct ConvertTree {
}
};

PersistentStateTree::PersistentStateTree(Core::Configuration config, Core::Ref<const Am::AcousticModel> acousticModel, Bliss::LexiconRef lexicon)
PersistentStateTree::PersistentStateTree(Core::Configuration config, Core::Ref<const Am::AcousticModel> acousticModel, Bliss::LexiconRef lexicon, TreeBuilderFactory treeBuilderFactory)
: masterTree(0),
rootState(0),
ciRootState(0),
archive_(paramCacheArchive(Core::Configuration(config, "search-network"))),
acousticModel_(acousticModel),
lexicon_(lexicon),
config_(config) {
config_(config),
treeBuilderFactory_(treeBuilderFactory) {
if (acousticModel_.get() && lexicon_.get()) {
const Am::ClassicAcousticModel* am = required_cast(const Am::ClassicAcousticModel*, acousticModel.get());
Core::DependencySet d;
Expand Down Expand Up @@ -320,7 +321,7 @@ bool PersistentStateTree::read(Core::MappedArchiveReader in) {
in >> masterTree >> dependenciesChecksum;

if (dependenciesChecksum != dependencies_.getChecksum()) {
Core::Application::us()->log() << "dependencies of the network image don't equal the requiered dependencies with checksum " << dependenciesChecksum;
Core::Application::us()->log() << "dependencies of the network image don't equal the required dependencies with checksum " << dependenciesChecksum;
return false;
}

Expand Down Expand Up @@ -436,7 +437,7 @@ HMMStateNetwork::CleanupResult PersistentStateTree::cleanup(bool cleanupExits) {

Core::HashMap<StateId, StateId>::const_iterator targetNodeIt;
if (rootState) {
verify(cleanupResult.nodeMap.find(rootState) != cleanupResult.nodeMap.end()); //Root-node must stay unchanged
verify(cleanupResult.nodeMap.find(rootState) != cleanupResult.nodeMap.end()); // Root-node must stay unchanged
verify(cleanupResult.nodeMap.find(rootState)->second == rootState);
targetNodeIt = cleanupResult.nodeMap.find(rootState);
verify(targetNodeIt != cleanupResult.nodeMap.end());
Expand Down Expand Up @@ -512,7 +513,7 @@ void PersistentStateTree::dumpDotGraph(std::string file, const std::vector<int>&
int depth = 0;
if (!nodeDepths.empty())
depth = nodeDepths[node];
os << Core::form("n%d [label=\"%d\\nd=%d\\nm=%d", node, node, depth, structure.state(node).stateDesc.acousticModel);
os << Core::form("n%d [label=\"%d\\nd=%d\\nm=%d\\nt=%d", node, node, depth, structure.state(node).stateDesc.acousticModel, structure.state(node).stateDesc.transitionModelIndex);

for (HMMStateNetwork::SuccessorIterator target = structure.successors(node); target; ++target)
if (target.isLabel() && exits[target.label()].pronunciation != Bliss::LemmaPronunciation::invalidId)
Expand Down
29 changes: 17 additions & 12 deletions src/Search/AdvancedTreeSearch/PersistentStateTree.hh
Original file line number Diff line number Diff line change
Expand Up @@ -31,34 +31,38 @@ struct MyStandardValueHash {
}
};

class AbstractTreeBuilder;

namespace Search {
class HMMStateNetwork;
class StateTree;

class PersistentStateTree {
public:
using TreeBuilderFactory = std::function<std::unique_ptr<AbstractTreeBuilder>(Core::Configuration, const Bliss::Lexicon&, const Am::AcousticModel&, PersistentStateTree&, bool)>;

///@param lexicon This must be given if the resulting exits are supposed to be functional
PersistentStateTree(Core::Configuration config, Core::Ref<const Am::AcousticModel> acousticModel, Bliss::LexiconRef lexicon);
PersistentStateTree(Core::Configuration config, Core::Ref<const Am::AcousticModel> acousticModel, Bliss::LexiconRef lexicon, TreeBuilderFactory treeBuilderFactory);

///Builds this state tree.
/// Builds this state tree.
void build();

///Writes the current state of the state tree into the file,
///Returns whether writing was successful
/// Writes the current state of the state tree into the file,
/// Returns whether writing was successful
bool write(int transformation = 0);

///Reads the state tree from the file.
/// Reads the state tree from the file.
///@return Whether the reading was successful.
bool read(int transformation = 0);

///Cleans up the structure, saving memory and allowing a more efficient iteration.
///Node and tree IDs may be changed.
/// Cleans up the structure, saving memory and allowing a more efficient iteration.
/// Node and tree IDs may be changed.
///@return An object that contains a mapping representing the index changes.
HMMStateNetwork::CleanupResult cleanup(bool cleanupExits = true);

///Removes all outputs from the network
///Also performs a cleanup, so the search network must already be clean
///for indices to stay equal
/// Removes all outputs from the network
/// Also performs a cleanup, so the search network must already be clean
/// for indices to stay equal
void removeOutputs();

u32 getChecksum() const;
Expand Down Expand Up @@ -128,11 +132,12 @@ private:
Core::Ref<const Am::AcousticModel> acousticModel_;
Bliss::LexiconRef lexicon_;
Core::Configuration config_;
TreeBuilderFactory treeBuilderFactory_;

//Writes the whole state network into the given stream
// Writes the whole state network into the given stream
void write(Core::MappedArchiveWriter writer);

//Reads the state network from the given stream.
// Reads the state network from the given stream.
//@return Whether the reading was successful.
bool read(Core::MappedArchiveReader reader);
};
Expand Down
Loading

0 comments on commit 6a9ac9c

Please sign in to comment.