Skip to content

Commit

Permalink
refactor(user_dictionary): collect into UserDictEntryIterators
Browse files Browse the repository at this point in the history
  • Loading branch information
lotem committed Oct 4, 2020
1 parent 41e9611 commit bb94093
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 72 deletions.
47 changes: 22 additions & 25 deletions src/rime/dict/user_dictionary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ struct DfsState {
TickCount present_tick;
Code code;
vector<double> credibility;
an<UserDictEntryCollector> collector;
map<int, DictEntryList> query_result;
an<DbAccessor> accessor;
string key;
string value;
Expand Down Expand Up @@ -70,32 +70,22 @@ void DfsState::RecruitEntry(size_t pos) {
if (e) {
e->code = code;
DLOG(INFO) << "add entry at pos " << pos;
(*collector)[pos].push_back(e);
query_result[pos].push_back(e);
}
}

// UserDictEntryIterator members

void UserDictEntryIterator::Add(const an<DictEntry>& entry) {
if (!entries_) {
entries_ = New<DictEntryList>();
}
entries_->push_back(entry);
void UserDictEntryIterator::Add(an<DictEntry>&& entry) {
cache_.push_back(std::move(entry));
}

void UserDictEntryIterator::SortRange(size_t start, size_t count) {
if (entries_)
entries_->SortRange(start, count);
void UserDictEntryIterator::SetEntries(DictEntryList&& entries) {
cache_ = std::move(entries);
}

bool UserDictEntryIterator::Release(DictEntryList* receiver) {
if (!entries_)
return false;
if (receiver)
entries_->swap(*receiver);
entries_.reset();
index_ = 0;
return true;
void UserDictEntryIterator::SortRange(size_t start, size_t count) {
cache_.SortRange(start, count);
}

void UserDictEntryIterator::AddFilter(DictEntryFilter filter) {
Expand All @@ -111,7 +101,7 @@ an<DictEntry> UserDictEntryIterator::Peek() {
if (exhausted()) {
return nullptr;
}
return (*entries_)[index_];
return cache_[index_];
}

bool UserDictEntryIterator::FindNextEntry() {
Expand Down Expand Up @@ -255,6 +245,14 @@ void UserDictionary::DfsLookup(const SyllableGraph& syll_graph,
}
}

// Builds a new UserDictEntryCollector from the lookup results in |source|.
// For every (key, list) pair in |source|, the list is moved into the
// iterator stored under the same key in the returned collector, so the
// lists left behind in |source| are in a moved-from state afterwards.
static an<UserDictEntryCollector> collect(map<int, DictEntryList>* source) {
  auto collector = New<UserDictEntryCollector>();
  for (auto it = source->begin(); it != source->end(); ++it) {
    // operator[] creates the iterator for this key on first access.
    auto& slot = (*collector)[it->first];
    slot.SetEntries(std::move(it->second));
  }
  return collector;
}

an<UserDictEntryCollector>
UserDictionary::Lookup(const SyllableGraph& syll_graph,
size_t start_pos,
Expand All @@ -268,18 +266,17 @@ UserDictionary::Lookup(const SyllableGraph& syll_graph,
FetchTickCount();
state.present_tick = tick_ + 1;
state.credibility.push_back(initial_credibility);
state.collector = New<UserDictEntryCollector>();
state.accessor = db_->Query("");
state.accessor->Jump(" "); // skip metadata
string prefix;
DfsLookup(syll_graph, start_pos, prefix, &state);
if (state.collector->empty())
if (state.query_result.empty())
return nullptr;
// sort each group of homophones by weight
for (auto& v : *state.collector) {
for (auto& v : state.query_result) {
v.second.Sort();
}
return state.collector;
return collect(&state.query_result);
}

size_t UserDictionary::LookupWords(UserDictEntryIterator* result,
Expand All @@ -289,7 +286,7 @@ size_t UserDictionary::LookupWords(UserDictEntryIterator* result,
string* resume_key) {
TickCount present_tick = tick_ + 1;
size_t len = input.length();
size_t start = result->size();
size_t start = result->cache_size();
size_t count = 0;
size_t exact_match_count = 0;
const string kEnd = "\xff";
Expand Down Expand Up @@ -328,7 +325,7 @@ size_t UserDictionary::LookupWords(UserDictEntryIterator* result,
e->comment = "~" + full_code.substr(len);
e->remaining_code_length = full_code.length() - len;
}
result->Add(e);
result->Add(std::move(e));
++count;
if (is_exact_match)
++exact_match_count;
Expand Down
17 changes: 8 additions & 9 deletions src/rime/dict/user_dictionary.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,34 +15,33 @@

namespace rime {

struct UserDictEntryCollector : map<size_t, DictEntryList> {
};

class UserDictEntryIterator : public DictEntryFilterBinder {
public:
UserDictEntryIterator() = default;

void Add(const an<DictEntry>& entry);
void Add(an<DictEntry>&& entry);
void SetEntries(DictEntryList&& entries);
void SortRange(size_t start, size_t count);
bool Release(DictEntryList* receiver);

void AddFilter(DictEntryFilter filter) override;
an<DictEntry> Peek();
bool Next();
bool exhausted() const {
return !entries_ || index_ >= entries_->size();
return index_ >= cache_.size();
}
size_t size() const {
return entries_ ? entries_->size() : 0;
size_t cache_size() const {
return cache_.size();
}

protected:
bool FindNextEntry();

an<DictEntryList> entries_;
DictEntryList cache_;
size_t index_ = 0;
};

using UserDictEntryCollector = map<size_t, UserDictEntryIterator>;

class Schema;
class Table;
class Prism;
Expand Down
3 changes: 1 addition & 2 deletions src/rime/gear/poet.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,12 @@

#include <rime/common.h>
#include <rime/translation.h>
#include <rime/dict/user_dictionary.h>
#include <rime/gear/translator_commons.h>
#include <rime/gear/contextual_translation.h>

namespace rime {

using WordGraph = map<int, UserDictEntryCollector>;
using WordGraph = map<int, map<int, DictEntryList>>;

class Grammar;
class Language;
Expand Down
60 changes: 34 additions & 26 deletions src/rime/gear/script_translator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <rime/algo/syllabifier.h>
#include <rime/dict/corrector.h>
#include <rime/dict/dictionary.h>
#include <rime/dict/user_dictionary.h>
#include <rime/gear/poet.h>
#include <rime/gear/script_translator.h>
#include <rime/gear/translator_commons.h>
Expand Down Expand Up @@ -121,9 +122,11 @@ class ScriptTranslation : public Translation {
protected:
bool CheckEmpty();
bool IsNormalSpelling() const;
an<Sentence> MakeSentence(Dictionary* dict,
UserDictionary* user_dict);
void PrepareCandidate();
template <class QueryResult>
void EnrollEntries(map<int, DictEntryList>& entries_by_end_pos,
const an<QueryResult>& query_result);
an<Sentence> MakeSentence(Dictionary* dict, UserDictionary* user_dict);

ScriptTranslator* translator_;
Poet* poet_;
Expand All @@ -138,7 +141,6 @@ class ScriptTranslation : public Translation {

DictEntryCollector::reverse_iterator phrase_iter_;
UserDictEntryCollector::reverse_iterator user_phrase_iter_;
size_t user_phrase_index_ = 0;

size_t max_corrections_ = 4;
size_t correction_count_ = 0;
Expand Down Expand Up @@ -402,10 +404,9 @@ bool ScriptTranslation::Next() {
}
if (user_phrase_code_length > 0 &&
user_phrase_code_length >= phrase_code_length) {
DictEntryList& entries(user_phrase_iter_->second);
if (++user_phrase_index_ >= entries.size()) {
UserDictEntryIterator& uter(user_phrase_iter_->second);
if (!uter.Next()) {
++user_phrase_iter_;
user_phrase_index_ = 0;
}
}
else if (phrase_code_length > 0) {
Expand Down Expand Up @@ -486,8 +487,8 @@ void ScriptTranslation::PrepareCandidate() {
an<Phrase> cand;
if (user_phrase_code_length > 0 &&
user_phrase_code_length >= phrase_code_length) {
DictEntryList& entries(user_phrase_iter_->second);
const auto& entry(entries[user_phrase_index_]);
UserDictEntryIterator& uter = user_phrase_iter_->second;
const auto& entry = uter.Peek();
DLOG(INFO) << "user phrase '" << entry->text
<< "', code length: " << user_phrase_code_length;
cand = New<Phrase>(translator_->language(),
Expand All @@ -500,8 +501,8 @@ void ScriptTranslation::PrepareCandidate() {
(IsNormalSpelling() ? 0.5 : -0.5));
}
else if (phrase_code_length > 0) {
DictEntryIterator& iter(phrase_iter_->second);
const auto& entry(iter.Peek());
DictEntryIterator& iter = phrase_iter_->second;
const auto& entry = iter.Peek();
DLOG(INFO) << "phrase '" << entry->text
<< "', code length: " << phrase_code_length;
cand = New<Phrase>(translator_->language(),
Expand All @@ -522,6 +523,23 @@ bool ScriptTranslation::CheckEmpty() {
return exhausted();
}

// Copies entries from |query_result| into |entries_by_end_pos|.
// For each (key, iterator) pair in |*query_result|, entries are taken from
// the iterator one at a time and appended to the list stored under the same
// key, until that list holds translator_->max_homophones() entries or the
// iterator has nothing more to offer. The iterators are advanced in place.
// A null |query_result| is ignored.
template <class QueryResult>
void ScriptTranslation::EnrollEntries(
    map<int, DictEntryList>& entries_by_end_pos,
    const an<QueryResult>& query_result) {
  if (!query_result)
    return;
  for (auto& group : *query_result) {
    // The destination list is created even if nothing ends up being added.
    DictEntryList& bucket = entries_by_end_pos[group.first];
    auto& source = group.second;
    for (;;) {
      if (!(bucket.size() < translator_->max_homophones()))
        break;
      if (source.exhausted())
        break;
      bucket.push_back(source.Peek());
      if (!source.Next())
        break;
    }
  }
}

an<Sentence> ScriptTranslation::MakeSentence(Dictionary* dict,
UserDictionary* user_dict) {
const int kMaxSyllablesForUserPhraseQuery = 5;
Expand All @@ -530,23 +548,13 @@ an<Sentence> ScriptTranslation::MakeSentence(Dictionary* dict,
for (const auto& x : syllable_graph.edges) {
auto& same_start_pos = graph[x.first];
if (user_dict) {
auto user_phrase = user_dict->Lookup(syllable_graph, x.first,
kMaxSyllablesForUserPhraseQuery);
if (user_phrase)
same_start_pos.swap(*user_phrase);
}
if (auto phrase = dict->Lookup(syllable_graph, x.first)) {
// merge lookup results
for (auto& y : *phrase) {
DictEntryList& homophones = same_start_pos[y.first];
while (homophones.size() < translator_->max_homophones() &&
!y.second.exhausted()) {
homophones.push_back(y.second.Peek());
if (!y.second.Next())
break;
}
}
EnrollEntries(same_start_pos,
user_dict->Lookup(syllable_graph,
x.first,
kMaxSyllablesForUserPhraseQuery));
}
// merge lookup results
EnrollEntries(same_start_pos, dict->Lookup(syllable_graph, x.first));
}
if (auto sentence =
poet_->MakeSentence(graph,
Expand Down
18 changes: 8 additions & 10 deletions src/rime/gear/table_translator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,6 @@ class SentenceTranslation : public Translation {
an<Sentence> sentence_;
DictEntryCollector collector_;
UserDictEntryCollector user_phrase_collector_;
size_t user_phrase_index_ = 0;
string input_;
size_t start_;
};
Expand Down Expand Up @@ -441,9 +440,8 @@ bool SentenceTranslation::Next() {
}
if (PreferUserPhrase()) {
auto r = user_phrase_collector_.rbegin();
if (++user_phrase_index_ >= r->second.size()) {
if (!r->second.Next()) {
user_phrase_collector_.erase(r->first);
user_phrase_index_ = 0;
}
}
else {
Expand All @@ -467,7 +465,7 @@ an<Candidate> SentenceTranslation::Peek() {
if (is_user_phrase) {
auto r = user_phrase_collector_.rbegin();
code_length = r->first;
entry = r->second[user_phrase_index_];
entry = r->second.Peek();
}
else {
auto r = collector_.rbegin();
Expand Down Expand Up @@ -603,9 +601,9 @@ TableTranslator::MakeSentence(const string& input, size_t start,
if (include_prefix_phrases && start_pos == 0) {
// also provide words for manual composition
// uter must not be consumed
uter.Release(&user_phrase_collector[consumed_length]);
DLOG(INFO) << "user phrase[" << consumed_length << "]: "
<< user_phrase_collector[consumed_length].size();
user_phrase_collector[consumed_length] = std::move(uter);
DLOG(INFO) << "user phrase[" << consumed_length << "] cached: "
<< user_phrase_collector[consumed_length].cache_size();
}
}
if (resume_key > active_key &&
Expand Down Expand Up @@ -641,9 +639,9 @@ TableTranslator::MakeSentence(const string& input, size_t start,
if (include_prefix_phrases && start_pos == 0) {
// also provide words for manual composition
// uter must not be consumed
uter.Release(&user_phrase_collector[consumed_length]);
DLOG(INFO) << "unity phrase[" << consumed_length << "]: "
<< user_phrase_collector[consumed_length].size();
user_phrase_collector[consumed_length] = std::move(uter);
DLOG(INFO) << "unity phrase[" << consumed_length << "] cached: "
<< user_phrase_collector[consumed_length].cache_size();
}
}
if (resume_key > active_key &&
Expand Down

0 comments on commit bb94093

Please sign in to comment.