Skip to content

Commit

Permalink
spell check before query
Browse files Browse the repository at this point in the history
  • Loading branch information
tomfran committed Jan 26, 2024
1 parent 7f29302 commit 7aca2ca
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 16 deletions.
6 changes: 5 additions & 1 deletion search/src/index/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ impl Index {
}
}

pub fn get_term_postings(&mut self, term: &str) -> Option<postings::PostingList> {
pub fn get_term_postings(&mut self, term: &str) -> Option<PostingList> {
self.vocabulary
.get_term_index(term)
.map(|i| self.postings.load_postings_list(i))
Expand All @@ -64,6 +64,10 @@ impl Index {
pub fn get_document_path(&self, doc_id: u32) -> String {
self.documents.get_doc_path(doc_id)
}

pub fn spellcheck_term(&self, term: &str) -> Option<String> {
self.vocabulary.spellcheck_term(term)
}
}

#[cfg(test)]
Expand Down
69 changes: 59 additions & 10 deletions search/src/index/vocabulary.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::cmp::min;

use super::{utils, InMemory, VOCABULARY_ALPHA_EXTENSION};
use crate::disk::{bits_reader::BitsReader, bits_writer::BitsWriter};
use fxhash::FxHashMap;
Expand Down Expand Up @@ -98,28 +100,62 @@ impl Vocabulary {
self.term_to_index.get(term).copied()
}

#[allow(dead_code)]

pub fn get_term_index_spellcheck(&self, term: &str) -> Option<usize> {
self.get_term_index(term)
.or_else(|| self.get_closest_index(term))
pub fn spellcheck_term(&self, term: &str) -> Option<String> {
if self.term_to_index.contains_key(term) {
Some(term.to_string())
} else {
self.get_closest_index(term)
.and_then(|i| self.index_to_term.get(i).cloned())
}
}
#[allow(dead_code)]

fn get_closest_index(&self, term: &str) -> Option<usize> {
let candidates = (0..term.len() - 2)
.map(|i| term[i..i + 3].to_string())
.filter_map(|t| self.trigram_index.get(&t))
.flat_map(|v| v.iter());

// find lowest levenshtein distance with maximum frequency
candidates
.min_by_key(|i| Self::distance(term, &self.index_to_term[**i]))
.min_by_key(|i| {
(
Self::levenshtein_distance(term, &self.index_to_term[**i]),
-(self.frequencies[**i] as i32),
)
})
.copied()
}

#[allow(unused_variables)]
fn distance(s1: &str, s2: &str) -> u32 {
todo!()
fn levenshtein_distance(s1: &str, s2: &str) -> usize {
if s1.len() > s2.len() {
return Self::levenshtein_distance(s2, s1);
}

let n = s1.len() + 1;
let m = s2.len() + 1;

if n == 0 {
return m;
}

let mut dp = vec![vec![0; m]; n];

for i in 0..m {
dp[0][i] = i;
}

for (i, c1) in s1.chars().enumerate() {
dp[i][0] = i;
for (j, c2) in s2.chars().enumerate() {
if c1 == c2 {
dp[i + 1][j + 1] = dp[i][j];
} else {
dp[i + 1][j + 1] = 1 + min(dp[i][j], min(dp[i + 1][j], dp[i][j + 1]));
}
}
}

dp[n - 1][m - 1]
}
}

Expand Down Expand Up @@ -165,5 +201,18 @@ mod tests {
assert_eq!(*loaded_vocabulary.trigram_index.get("hel").unwrap(), [0]);
assert_eq!(*loaded_vocabulary.trigram_index.get("ell").unwrap(), [0]);
assert_eq!(*loaded_vocabulary.trigram_index.get("rld").unwrap(), [1]);

assert_eq!(loaded_vocabulary.spellcheck_term("hell").unwrap(), "hello");
assert_eq!(loaded_vocabulary.spellcheck_term("wrld").unwrap(), "world");
assert_eq!(loaded_vocabulary.spellcheck_term("he"), None);
}

#[test]
fn test_levenshtein_distance() {
assert_eq!(Vocabulary::levenshtein_distance("hello", "hello"), 0);
assert_eq!(Vocabulary::levenshtein_distance("hello", ""), 5);
assert_eq!(Vocabulary::levenshtein_distance("", ""), 0);
assert_eq!(Vocabulary::levenshtein_distance("cat", "cats"), 1);
assert_eq!(Vocabulary::levenshtein_distance("abc", "xyz"), 3);
}
}
11 changes: 9 additions & 2 deletions search/src/query/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,15 @@ impl Processor {
pub fn query(&mut self, query: &str, num_results: usize) -> Result {
let start_time = Instant::now();

let tokens = self.index.get_query_tokens(query);
// spellcheck phase
let tokens: Vec<String> = self
.index
.get_query_tokens(query)
.iter()
.filter_map(|t| self.index.spellcheck_term(t))
.collect();

// retrieve documents
let documents = self
.get_sorted_document_entries(&tokens.clone(), num_results)
.iter()
Expand Down Expand Up @@ -97,7 +104,7 @@ impl Processor {

let mut selector = DocumentSelector::new(num_results);
let num_tokens = tokens.len();
for (id, score) in scores.iter_mut() {
for (id, score) in &mut scores {
// tf-idf score must be divided by the document len
score.tf_idf /= self.index.get_document_len(*id) as f64;
selector.push(*id, Processor::compute_score(score, num_tokens));
Expand Down
4 changes: 2 additions & 2 deletions server/templates/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
<title>search-rs</title>
</head>

<body class=" dark bg-neutral-100 dark:bg-neutral-900 text-neutral-700 dark:text-white">
<body class=" dark white dark:bg-neutral-900 text-neutral-900 dark:text-white">

<!-- Main Content -->
<div class="container mx-auto mt-16 max-w-3xl">
Expand All @@ -67,7 +67,7 @@
<div class="mb-6">
<h1 class="text-3xl font-medium mb-10">Index on {{index_path}}</h1>
<input type="text"
class="outline-neutral-300 dark:outline-neutral-900 w-full p-4 rounded-md dark:bg-neutral-200 dark:bg-neutral-800"
class="outline-neutral-300 dark:outline-neutral-900 w-full p-4 rounded-md bg-neutral-100 dark:bg-neutral-800"
placeholder="Enter your search query..." autofocus name="query" hx-post="/query" hx-ext='json-enc'
hx-target=".search-results" hx-trigger="keyup[keyCode==13]">
</div>
Expand Down
2 changes: 1 addition & 1 deletion server/templates/query.html
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ <h1 class=" font-light text-md mb-6">
{% for doc in documents %}


<div id="{{doc.path}}" class="toggle-container bg-white dark:bg-neutral-800 p-6 rounded-md mb-6">
<div id="{{doc.path}}" class="toggle-container bg-neutral-100 dark:bg-neutral-800 p-6 rounded-md mb-6">
<div id="{{doc.path}}_closed">
<h2 class="text-xl font-semibold mb-4">
{{ doc.path }}
Expand Down

0 comments on commit 7aca2ca

Please sign in to comment.