Skip to content

Commit

Permalink
[feature](inverted index) add ordered functionality to match_phrase q…
Browse files Browse the repository at this point in the history
…uery (#37751)

## Proposed changes

1. select count() from tbl where b match_phrase 'the brown ~2+';
  • Loading branch information
zzzxl1993 authored Jul 15, 2024
1 parent d1fc4e2 commit b7dbd5c
Show file tree
Hide file tree
Showing 6 changed files with 390 additions and 98 deletions.
253 changes: 173 additions & 80 deletions be/src/olap/rowset/segment_v2/inverted_index/query/phrase_query.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,105 @@

namespace doris::segment_v2 {

// Reports whether document `doc` satisfies the phrase constraint implemented
// by the concrete matcher.
//
// The per-posting state is first re-initialised for `doc` via reset(); the
// actual matching is then delegated to Derived::next_match() (static, CRTP-style
// dispatch through a cast of `this` to the derived type).
template <typename Derived>
bool PhraseMatcherBase<Derived>::matches(int32_t doc) {
    reset(doc);
    auto& self = static_cast<Derived&>(*this);
    return self.next_match();
}

// Prepares every tracked posting for a fresh round of matching on `doc`.
//
// For each entry in `_postings`:
//   * the underlying postings cursor is advanced to `doc` unless docID()
//     already reports `doc`;
//   * `_freq` is refreshed from the cursor's freq();
//   * the position bookkeeping is cleared (`_pos = -1`, `_upTo = 0`), i.e. no
//     position has been consumed yet for this document.
template <typename Derived>
void PhraseMatcherBase<Derived>::reset(int32_t doc) {
    for (PostingsAndPosition& entry : _postings) {
        const bool already_on_doc = (entry._postings.docID() == doc);
        if (!already_on_doc) {
            entry._postings.advance(doc);
        }
        // freq() must be read after the cursor has been positioned above.
        entry._freq = entry._postings.freq();
        entry._upTo = 0;
        entry._pos = -1;
    }
}

// Moves `posting` forward until its current position is at least `target`.
//
// Positions are pulled one at a time with nextPosition(), and `_upTo` counts
// how many have been consumed so far.
//
// Returns true once `posting._pos >= target` (which may already hold on
// entry, in which case nothing is consumed). Returns false if all `_freq`
// positions were consumed while `_pos` is still below `target`.
template <typename Derived>
bool PhraseMatcherBase<Derived>::advance_position(PostingsAndPosition& posting, int32_t target) {
    while (posting._pos < target) {
        if (posting._upTo == posting._freq) {
            // Every position of this posting has been consumed.
            return false;
        }
        posting._pos = posting._postings.nextPosition();
        ++posting._upTo;
    }
    return true;
}

bool ExactPhraseMatcher::next_match() {
PostingsAndPosition& lead = _postings[0];
if (lead._upTo < lead._freq) {
lead._pos = lead._postings.nextPosition();
lead._upTo += 1;
} else {
return false;
}

while (true) {
int32_t phrasePos = lead._pos - lead._offset;

bool advance_head = false;
for (size_t j = 1; j < _postings.size(); ++j) {
PostingsAndPosition& posting = _postings[j];
int32_t expectedPos = phrasePos + posting._offset;
// advance up to the same position as the lead
if (!advance_position(posting, expectedPos)) {
return false;
}

if (posting._pos != expectedPos) { // we advanced too far
if (advance_position(lead, posting._pos - posting._offset + lead._offset)) {
advance_head = true;
break;
} else {
return false;
}
}
}
if (advance_head) {
continue;
}

return true;
}

return false;
}

bool OrderedSloppyPhraseMatcher::next_match() {
PostingsAndPosition* prev_posting = _postings.data();
while (prev_posting->_upTo < prev_posting->_freq) {
prev_posting->_pos = prev_posting->_postings.nextPosition();
prev_posting->_upTo += 1;
if (stretch_to_order(prev_posting) && _match_width <= _allowed_slop) {
return true;
}
}
return false;
}

// Places postings 1..N-1 strictly after their predecessor, starting from
// `prev_posting` (the caller passes the lead posting).
//
// Each posting is advanced to at least `previous position + 1`. The distance
// between where it lands and that earliest allowed position is added to
// `_match_width`, which is reset to 0 on entry.
//
// Returns false as soon as a posting cannot reach its earliest allowed
// position (`_match_width` then holds only a partial sum); returns true once
// every posting has been placed.
bool OrderedSloppyPhraseMatcher::stretch_to_order(PostingsAndPosition* prev_posting) {
    _match_width = 0;
    for (size_t idx = 1; idx < _postings.size(); ++idx) {
        PostingsAndPosition& current = _postings[idx];
        // First position that still keeps `current` after its predecessor.
        const auto earliest_pos = prev_posting->_pos + 1;
        if (!advance_position(current, earliest_pos)) {
            return false;
        }
        _match_width += (current._pos - earliest_pos);
        prev_posting = &current;
    }
    return true;
}

PhraseQuery::PhraseQuery(const std::shared_ptr<lucene::search::IndexSearcher>& searcher,
const TQueryOptions& query_options)
: _searcher(searcher), _query(std::make_unique<CL_NS(search)::PhraseQuery>()) {}
: _searcher(searcher) {}

PhraseQuery::~PhraseQuery() {
for (auto& term_doc : _term_docs) {
Expand All @@ -44,16 +140,20 @@ void PhraseQuery::add(const InvertedIndexQueryInfo& query_info) {
}

_slop = query_info.slop;
if (_slop <= 0) {
if (_slop == 0 || query_info.ordered) {
// Logic for no slop query and ordered phrase query
add(query_info.field_name, query_info.terms);
} else {
// Simple slop query follows the default phrase query algorithm
auto query = std::make_unique<CL_NS(search)::PhraseQuery>();
for (const auto& term : query_info.terms) {
std::wstring ws_term = StringUtil::string_to_wstring(term);
auto* t = _CLNEW lucene::index::Term(query_info.field_name.c_str(), ws_term.c_str());
_query->add(t);
query->add(t);
_CLDECDELETE(t);
}
_query->setSlop(_slop);
query->setSlop(_slop);
_matcher = std::move(query);
}
}

Expand All @@ -73,14 +173,33 @@ void PhraseQuery::add(const std::wstring& field_name, const std::vector<std::str
}

std::vector<TermIterator> iterators;
for (size_t i = 0; i < terms.size(); i++) {
std::wstring ws_term = StringUtil::string_to_wstring(terms[i]);
auto ensureTermPosition = [this, &iterators, &field_name](const std::string& term) {
std::wstring ws_term = StringUtil::string_to_wstring(term);
Term* t = _CLNEW Term(field_name.c_str(), ws_term.c_str());
_terms.push_back(t);
TermPositions* term_pos = _searcher->getReader()->termPositions(t);
_term_docs.push_back(term_pos);
iterators.emplace_back(term_pos);
_postings.emplace_back(term_pos, i);
return term_pos;
};

if (_slop == 0) {
ExactPhraseMatcher matcher;
for (size_t i = 0; i < terms.size(); i++) {
const auto& term = terms[i];
auto* term_pos = ensureTermPosition(term);
matcher._postings.emplace_back(term_pos, i);
}
_matcher = matcher;
} else {
OrderedSloppyPhraseMatcher matcher;
for (size_t i = 0; i < terms.size(); i++) {
const auto& term = terms[i];
auto* term_pos = ensureTermPosition(term);
matcher._postings.emplace_back(term_pos, i);
}
matcher._allowed_slop = _slop;
_matcher = matcher;
}

std::sort(iterators.begin(), iterators.end(), [](const TermIterator& a, const TermIterator& b) {
Expand All @@ -89,13 +208,17 @@ void PhraseQuery::add(const std::wstring& field_name, const std::vector<std::str

_lead1 = iterators[0];
_lead2 = iterators[1];
for (int32_t i = 2; i < _terms.size(); i++) {
for (int32_t i = 2; i < iterators.size(); i++) {
_others.push_back(iterators[i]);
}
}

void PhraseQuery::search(roaring::Roaring& roaring) {
if (_slop <= 0) {
if (std::holds_alternative<PhraseQueryPtr>(_matcher)) {
_searcher->_search(
std::get<PhraseQueryPtr>(_matcher).get(),
[&roaring](const int32_t docid, const float_t /*score*/) { roaring.add(docid); });
} else {
if (_lead1.isEmpty()) {
return;
}
Expand All @@ -104,10 +227,6 @@ void PhraseQuery::search(roaring::Roaring& roaring) {
return;
}
search_by_skiplist(roaring);
} else {
_searcher->_search(_query.get(), [&roaring](const int32_t docid, const float_t /*score*/) {
roaring.add(docid);
});
}
}

Expand All @@ -125,8 +244,7 @@ void PhraseQuery::search_by_bitmap(roaring::Roaring& roaring) {
void PhraseQuery::search_by_skiplist(roaring::Roaring& roaring) {
int32_t doc = 0;
while ((doc = do_next(_lead1.nextDoc())) != INT32_MAX) {
reset();
if (next_match()) {
if (matches(doc)) {
roaring.add(doc);
}
}
Expand Down Expand Up @@ -169,67 +287,21 @@ int32_t PhraseQuery::do_next(int32_t doc) {
}
}

bool PhraseQuery::next_match() {
PostingsAndPosition& lead = _postings[0];
if (lead._upTo < lead._freq) {
lead._pos = lead._postings.nextPosition();
lead._upTo += 1;
} else {
return false;
}

while (true) {
int32_t phrasePos = lead._pos - lead._offset;

bool advance_head = false;
for (size_t j = 1; j < _postings.size(); ++j) {
PostingsAndPosition& posting = _postings[j];
int32_t expectedPos = phrasePos + posting._offset;
// advance up to the same position as the lead
if (!advance_position(posting, expectedPos)) {
return false;
}

if (posting._pos != expectedPos) { // we advanced too far
if (advance_position(lead, posting._pos - posting._offset + lead._offset)) {
advance_head = true;
break;
bool PhraseQuery::matches(int32_t doc) {
return std::visit(
[&doc](auto&& m) -> bool {
using T = std::decay_t<decltype(m)>;
if constexpr (std::is_same_v<T, PhraseQueryPtr>) {
_CLTHROWA(CL_ERR_IllegalArgument,
"PhraseQueryPtr does not support matches function");
} else {
return false;
return m.matches(doc);
}
}
}
if (advance_head) {
continue;
}

return true;
}

return false;
}

bool PhraseQuery::advance_position(PostingsAndPosition& posting, int32_t target) {
while (posting._pos < target) {
if (posting._upTo == posting._freq) {
return false;
} else {
posting._pos = posting._postings.nextPosition();
posting._upTo += 1;
}
}
return true;
}

void PhraseQuery::reset() {
for (PostingsAndPosition& posting : _postings) {
posting._freq = posting._postings.freq();
posting._pos = -1;
posting._upTo = 0;
}
},
_matcher);
}

Status PhraseQuery::parser_slop(std::string& query, InvertedIndexQueryInfo& query_info) {
void PhraseQuery::parser_slop(std::string& query, InvertedIndexQueryInfo& query_info) {
auto is_digits = [](const std::string_view& str) {
return std::all_of(str.begin(), str.end(), [](unsigned char c) { return std::isdigit(c); });
};
Expand All @@ -240,17 +312,38 @@ Status PhraseQuery::parser_slop(std::string& query, InvertedIndexQueryInfo& quer
if (tilde_pos < query.size() - 1 && query[tilde_pos] == '~') {
size_t slop_pos = tilde_pos + 1;
std::string_view slop_str(query.data() + slop_pos, query.size() - slop_pos);
if (is_digits(slop_str)) {
auto result = std::from_chars(slop_str.begin(), slop_str.end(), query_info.slop);
if (result.ec != std::errc()) {
return Status::Error<doris::ErrorCode::INVERTED_INDEX_INVALID_PARAMETERS>(
"PhraseQuery parser failed: {}", query);
do {
if (slop_str.empty()) {
break;
}
query = query.substr(0, last_space_pos);
}

bool ordered = false;
if (slop_str.size() == 1) {
if (!std::isdigit(slop_str[0])) {
break;
}
} else {
if (slop_str.back() == '+') {
ordered = true;
slop_str.remove_suffix(1);
}
}

if (is_digits(slop_str)) {
auto result =
std::from_chars(slop_str.begin(), slop_str.end(), query_info.slop);
if (result.ec != std::errc()) {
break;
}
query_info.ordered = ordered;
query = query.substr(0, last_space_pos);
}
} while (false);
}
}
return Status::OK();
}

template class PhraseMatcherBase<ExactPhraseMatcher>;
template class PhraseMatcherBase<OrderedSloppyPhraseMatcher>;

} // namespace doris::segment_v2
Loading

0 comments on commit b7dbd5c

Please sign in to comment.