Skip to content

Commit

Permalink
Pass correction information in lattice node, so we can filter possibl…
Browse files Browse the repository at this point in the history
…y wrong out results.
  • Loading branch information
wengxt committed Mar 23, 2024
1 parent dab2b5d commit d2f8414
Show file tree
Hide file tree
Showing 7 changed files with 91 additions and 28 deletions.
1 change: 1 addition & 0 deletions src/libime/pinyin/constants.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

namespace libime {
constexpr float PINYIN_DISTANCE_PENALTY_FACTOR = 1.8;
constexpr int PINYIN_CORRECTION_FUZZY_FACTOR = 10;
}

#endif // _FCITX_LIBIME_PINYIN_CONSTANTS_H_
37 changes: 35 additions & 2 deletions src/libime/pinyin/pinyincontext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -470,11 +470,15 @@ void PinyinContext::update() {
float max = -std::numeric_limits<float>::max();
auto distancePenalty = d->ime_->model()->unknownPenalty() /
PINYIN_DISTANCE_PENALTY_FACTOR;
// Pull the phrase from lattice, this part is the word that's in the
// dict.
for (const auto &graphNode : graph.nodes(i)) {
auto distance = graph.distanceToEnd(graphNode);
auto adjust = static_cast<float>(distance) * distancePenalty;
for (const auto &latticeNode : d->lattice_.nodes(&graphNode)) {
if (latticeNode.from() == bos) {
if (latticeNode.from() == bos &&
!static_cast<const PinyinLatticeNode &>(latticeNode)
.isCorrection()) {
if (!d->ime_->model()->isNodeUnknown(latticeNode)) {
if (latticeNode.score() < min) {
min = latticeNode.score();
Expand All @@ -492,13 +496,42 @@ void PinyinContext::update() {
}
}
}

// Filter correction word based on score
for (const auto &graphNode : graph.nodes(i)) {
auto distance = graph.distanceToEnd(graphNode);
auto adjust = static_cast<float>(distance) * distancePenalty;
for (const auto &latticeNode : d->lattice_.nodes(&graphNode)) {
if (latticeNode.from() == bos &&
static_cast<const PinyinLatticeNode &>(latticeNode)
.isCorrection()) {
if (d->candidatesSet_.count(latticeNode.word())) {
continue;
}
if ((latticeNode.score() > min &&
latticeNode.score() + d->ime_->maxDistance() >
max) ||
static_cast<const PinyinLatticeNode &>(latticeNode)
.encodedPinyin()
.size() <= 2) {
d->candidates_.push_back(
latticeNode.toSentenceResult(adjust));
d->candidatesSet_.insert(latticeNode.word());
}
}
}
}

// This part is the phrase that's constructable from lattice.
for (const auto &graphNode : graph.nodes(i)) {
auto distance = graph.distanceToEnd(graphNode);
auto adjust = static_cast<float>(distance) * distancePenalty;
for (const auto &latticeNode : d->lattice_.nodes(&graphNode)) {
if (latticeNode.from() != bos &&
latticeNode.score() > min &&
latticeNode.score() + d->ime_->maxDistance() > max) {
latticeNode.score() + d->ime_->maxDistance() > max &&
!static_cast<const PinyinLatticeNode &>(latticeNode)
.anyCorrectionOnPath()) {
auto fullWord = latticeNode.fullWord();
if (d->candidatesSet_.count(fullWord)) {
continue;
Expand Down
20 changes: 19 additions & 1 deletion src/libime/pinyin/pinyindecoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
* SPDX-License-Identifier: LGPL-2.1-or-later
*/

#include "libime/pinyin/pinyindecoder.h"
#include "pinyindecoder_p.h"
#include <cmath>

namespace libime {

Expand All @@ -26,6 +26,24 @@ const std::string &PinyinLatticeNode::encodedPinyin() const {
return d_ptr->encodedPinyin_;
}

bool PinyinLatticeNode::isCorrection() const {
if (!d_ptr) {
return false;
}
return d_ptr->isCorrection_;
}

bool PinyinLatticeNode::anyCorrectionOnPath() const {
const auto *pivot = this;
while (pivot) {
if (pivot->isCorrection()) {
return true;
}
pivot = static_cast<PinyinLatticeNode *>(pivot->prev());
}
return false;
}

LatticeNode *PinyinDecoder::createLatticeNodeImpl(
const SegmentGraphBase &graph, const LanguageModelBase *model,
std::string_view word, WordIndex idx, SegmentGraphPath path,
Expand Down
2 changes: 2 additions & 0 deletions src/libime/pinyin/pinyindecoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ class LIBIMEPINYIN_EXPORT PinyinLatticeNode : public LatticeNode {
virtual ~PinyinLatticeNode();

const std::string &encodedPinyin() const;
bool isCorrection() const;
bool anyCorrectionOnPath() const;

private:
std::unique_ptr<PinyinLatticeNodePrivate> d_ptr;
Expand Down
5 changes: 3 additions & 2 deletions src/libime/pinyin/pinyindecoder_p.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@ namespace libime {

class PinyinLatticeNodePrivate : public LatticeNodeData {
public:
PinyinLatticeNodePrivate(std::string_view encodedPinyin)
: encodedPinyin_(encodedPinyin) {}
PinyinLatticeNodePrivate(std::string_view encodedPinyin, bool isCorrection)
: encodedPinyin_(encodedPinyin), isCorrection_(isCorrection) {}

std::string encodedPinyin_;
bool isCorrection_ = false;
};
} // namespace libime

Expand Down
47 changes: 27 additions & 20 deletions src/libime/pinyin/pinyindictionary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
*/

#include "pinyindictionary.h"
#include "constants.h"
#include "libime/core/datrie.h"
#include "libime/core/lattice.h"
#include "libime/core/lrucache.h"
Expand All @@ -23,14 +24,13 @@
#include <queue>
#include <string>
#include <string_view>
#include <type_traits>

namespace libime {

namespace {
const float fuzzyCost = std::log10(0.5F);
const size_t minimumLongWordLength = 3;
const float invalidPinyinCost = -100.0f;
const float invalidPinyinCost = -100.0F;
const char pinyinHanziSep = '!';

constexpr uint32_t pinyinBinaryFormatMagic = 0x000fc613;
Expand Down Expand Up @@ -194,7 +194,7 @@ size_t fuzzyFactor(PinyinFuzzyFlags flags) {
size_t factor = 0;
if (flags.test(PinyinFuzzyFlag::Correction)) {
flags = flags.unset(PinyinFuzzyFlag::Correction);
factor += 4;
factor += PINYIN_CORRECTION_FUZZY_FACTOR;
}
if (flags != 0) {
factor += 1;
Expand Down Expand Up @@ -354,8 +354,8 @@ void matchWordsOnTrie(const MatchedPinyinPath &path, bool matchLongWord,
float extraCost = fuzzies * fuzzyCost;
if (matchLongWord) {
path.trie()->foreach(
[&path, &callback, extraCost](PinyinTrie::value_type value,
size_t len, uint64_t pos) {
[&path, &callback, extraCost, fuzzies](
PinyinTrie::value_type value, size_t len, uint64_t pos) {
std::string s;
s.reserve(len + path.size() * 2);
path.trie()->suffix(s, len + path.size() * 2, pos);
Expand All @@ -370,7 +370,8 @@ void matchWordsOnTrie(const MatchedPinyinPath &path, bool matchLongWord,
float overLengthCost = fuzzyCost * lengthDiff;

callback(encodedPinyin, hanzi,
value + extraCost + overLengthCost);
value + extraCost + overLengthCost,
fuzzies >= PINYIN_CORRECTION_FUZZY_FACTOR);
}
return true;
},
Expand All @@ -383,15 +384,16 @@ void matchWordsOnTrie(const MatchedPinyinPath &path, bool matchLongWord,
}

path.trie()->foreach(
[&path, &callback, extraCost](PinyinTrie::value_type value,
size_t len, uint64_t pos) {
[&path, &callback, extraCost, fuzzies](
PinyinTrie::value_type value, size_t len, uint64_t pos) {
std::string s;
s.reserve(len + path.size() * 2 + 1);
path.trie()->suffix(s, len + path.size() * 2 + 1, pos);
std::string_view view(s);
auto encodedPinyin = view.substr(0, path.size() * 2);
auto hanzi = view.substr(path.size() * 2 + 1);
callback(encodedPinyin, hanzi, value + extraCost);
callback(encodedPinyin, hanzi, value + extraCost,
fuzzies >= PINYIN_CORRECTION_FUZZY_FACTOR);
return true;
},
pos);
Expand Down Expand Up @@ -421,12 +423,12 @@ bool PinyinDictionaryPrivate::matchWordsForOnePath(
const bool matchLongWord =
(path.path_.back() == &context.graph_.end() && matchLongWordEnabled);

auto foundOneWord = [&path, &prevNode, &matched,
&context](std::string_view encodedPinyin,
WordNode &word, float cost) {
context.callback_(
path.path_, word, cost,
std::make_unique<PinyinLatticeNodePrivate>(encodedPinyin));
auto foundOneWord = [&path, &prevNode, &matched, &context](
std::string_view encodedPinyin, WordNode &word,
float cost, bool isCorrection) {
context.callback_(path.path_, word, cost,
std::make_unique<PinyinLatticeNodePrivate>(
encodedPinyin, isCorrection));
if (path.size() == 1 &&
path.path_[path.path_.size() - 2] == &prevNode) {
matched = true;
Expand All @@ -445,23 +447,28 @@ bool PinyinDictionaryPrivate::matchWordsForOnePath(
auto &items = *result;
matchWordsOnTrie(path, matchLongWordEnabled,
[&items](std::string_view encodedPinyin,
std::string_view hanzi, float cost) {
items.emplace_back(hanzi, cost, encodedPinyin);
std::string_view hanzi, float cost,
bool isCorrection) {
items.emplace_back(hanzi, cost, encodedPinyin,
isCorrection);
});
}
for (auto &item : *result) {
if (!matchLongWord &&
item.encodedPinyin_.size() / 2 > path.size()) {
continue;
}
foundOneWord(item.encodedPinyin_, item.word_, item.value_);
foundOneWord(item.encodedPinyin_, item.word_, item.value_,
item.isCorrection_);
}
} else {
matchWordsOnTrie(path, matchLongWord,
[&foundOneWord](std::string_view encodedPinyin,
std::string_view hanzi, float cost) {
std::string_view hanzi, float cost,
bool isCorrection) {
WordNode word(hanzi, InvalidWordIndex);
foundOneWord(encodedPinyin, word, cost);
foundOneWord(encodedPinyin, word, cost,
isCorrection);
});
}

Expand Down
7 changes: 4 additions & 3 deletions src/libime/pinyin/pinyinmatchstate_p.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,13 @@ struct MatchedPinyinTrieNodes {
// adjustment score.
struct PinyinMatchResult {
PinyinMatchResult(std::string_view s, float value,
std::string_view encodedPinyin)
std::string_view encodedPinyin, bool isCorrection)
: word_(s, InvalidWordIndex), value_(value),
encodedPinyin_(encodedPinyin) {}
encodedPinyin_(encodedPinyin), isCorrection_(isCorrection) {}
WordNode word_;
float value_;
float value_ = 0.0F;
std::string encodedPinyin_;
bool isCorrection_ = false;
};

// class to store current SegmentGraphPath leads to this match and the match
Expand Down

0 comments on commit d2f8414

Please sign in to comment.