From 6bd9ecadba3cae21124c49fe72be4446e58c809f Mon Sep 17 00:00:00 2001
From: Francesco
Date: Sat, 27 Jan 2024 16:24:13 +0100
Subject: [PATCH] Simplify structure, use BM25 instead of TF-IDF

---
 README.md                                    |   4 +-
 search/src/{index => engine}/builder.rs      |   4 +-
 search/src/{index => engine}/documents.rs    |  45 ++--
 .../document_selector.rs => engine/heap.rs}  |  48 ++--
 search/src/engine/mod.rs                     | 213 ++++++++++++++++++
 search/src/{index => engine}/postings.rs     |   0
 search/src/{index => engine}/preprocessor.rs |   0
 search/src/{index => engine}/utils.rs        |   0
 search/src/{index => engine}/vocabulary.rs   |   2 +-
 search/src/index/mod.rs                      | 107 ---------
 search/src/lib.rs                            |   3 +-
 search/src/main.rs                           |  11 +-
 search/src/query/mod.rs                      | 148 ------------
 server/Cargo.toml                            |   1 +
 server/src/main.rs                           |  33 ++-
 server/templates/index.html                  |   2 +-
 16 files changed, 293 insertions(+), 328 deletions(-)
 rename search/src/{index => engine}/builder.rs (97%)
 rename search/src/{index => engine}/documents.rs (79%)
 rename search/src/{query/document_selector.rs => engine/heap.rs} (56%)
 create mode 100644 search/src/engine/mod.rs
 rename search/src/{index => engine}/postings.rs (100%)
 rename search/src/{index => engine}/preprocessor.rs (100%)
 rename search/src/{index => engine}/utils.rs (100%)
 rename search/src/{index => engine}/vocabulary.rs (98%)
 delete mode 100644 search/src/index/mod.rs
 delete mode 100644 search/src/query/mod.rs

diff --git a/README.md b/README.md
index 37f399f..efa76ea 100644
--- a/README.md
+++ b/README.md
@@ -16,10 +16,10 @@ Search engine written in Rust, based on an inverted index on disk.
 **Index construction**
 - [x] In-memory datasets index construction;
 - [x] Proper vocabulary and paths on disk;
-- [ ] Spelling correction index: in progress.
+- [x] Spelling correction index;
 
 **Queries**
-- [x] Tf-idf ranked retrieval;
+- [x] BM25 scoring;
 - [x] Window computation;
 
 **Evaluation**
diff --git a/search/src/index/builder.rs b/search/src/engine/builder.rs
similarity index 97%
rename from search/src/index/builder.rs
rename to search/src/engine/builder.rs
index d42b8aa..971922f 100644
--- a/search/src/index/builder.rs
+++ b/search/src/engine/builder.rs
@@ -19,7 +19,7 @@ const PROGRESS_CHARS: &str = "=> ";
 
 const CUTOFF_THRESHOLD: f64 = 0.8;
 
-pub fn build_index(input_dir: &str, output_path: &str, preprocessor: &Preprocessor) {
+pub fn build_engine(input_dir: &str, output_path: &str, preprocessor: &Preprocessor) {
     let index: InMemory = build_in_memory(input_dir, preprocessor);
     Postings::write_postings(&index, output_path);
     Vocabulary::write_vocabulary(&index, output_path);
@@ -59,7 +59,7 @@ fn build_in_memory(input_dir: &str, preprocessor: &Preprocessor) -> InMemory {
         // update documents array
         documents.lock().unwrap().push(Document {
             path: d.path().to_str().unwrap().to_string(),
-            lenght: tokens.len() as u32,
+            length: tokens.len() as u32,
         });
 
         let mut l_term_index_map = term_index_map.lock().unwrap();
diff --git a/search/src/index/documents.rs b/search/src/engine/documents.rs
similarity index 79%
rename from search/src/index/documents.rs
rename to search/src/engine/documents.rs
index 36abb90..196f30e 100644
--- a/search/src/index/documents.rs
+++ b/search/src/engine/documents.rs
@@ -4,11 +4,12 @@ use crate::disk::{bits_reader::BitsReader, bits_writer::BitsWriter};
 #[derive(Clone)]
 pub struct Document {
     pub path: String,
-    pub lenght: u32,
+    pub length: u32,
 }
 
 pub struct Documents {
     docs: Vec<Document>,
+    avg_len: f64,
 }
 
 impl Documents {
@@ -16,21 +17,26 @@
         let mut reader = BitsReader::new(&(input_path.to_string() + DOCUMENTS_EXTENSION));
 
         let mut prev = String::new();
-        let docs = (0..reader.read_vbyte())
+
+        let mut length_sum = 0;
+
+        let docs: Vec<Document> = (0..reader.read_vbyte())
             .map(|_| {
                 let p_len = reader.read_gamma();
                 let prefix: String = prev.chars().take(p_len as usize).collect();
-                let s = prefix + &reader.read_str();
-                prev = s.clone();
+                let path = prefix + &reader.read_str();
+                prev = path.clone();
+
+                let length = reader.read_vbyte();
+                length_sum += length;
 
-                Document {
-                    path: s,
-                    lenght: reader.read_vbyte(),
-                }
+                Document { path, length }
             })
             .collect();
 
-        Documents { docs }
+        let avg_len = length_sum as f64 / docs.len() as f64;
+
+        Documents { docs, avg_len }
     }
 
     pub fn write_documents(documents: &Vec<Document>, output_path: &str) {
@@ -47,7 +53,7 @@
             prev = &l.path;
 
             writer.write_str(&remaining);
-            writer.write_vbyte(l.lenght);
+            writer.write_vbyte(l.length);
         }
 
         writer.flush();
@@ -58,7 +64,11 @@
     }
 
     pub fn get_doc_len(&self, doc_id: u32) -> u32 {
-        self.docs[doc_id as usize].lenght
+        self.docs[doc_id as usize].length
+    }
+
+    pub fn get_avg_doc_len(&self) -> f64 {
+        self.avg_len
     }
 
     pub fn get_doc_path(&self, doc_id: u32) -> String {
@@ -79,11 +89,11 @@
         let documents = vec![
             Document {
                 path: "document1.txt".to_string(),
-                lenght: 100,
+                length: 100,
             },
             Document {
                 path: "document2.txt".to_string(),
-                lenght: 150,
+                length: 150,
             },
         ];
 
@@ -94,7 +104,7 @@
 
         for (i, d) in documents.iter().enumerate() {
            assert_eq!(loaded_documents.get_doc_path(i as u32), d.path);
-            assert_eq!(loaded_documents.get_doc_len(i as u32), d.lenght);
+            assert_eq!(loaded_documents.get_doc_len(i as u32), d.length);
         }
     }
 
@@ -103,23 +113,24 @@
         let documents = vec![
             Document {
                 path: "document1.txt".to_string(),
-                lenght: 100,
+                length: 100,
             },
             Document {
                 path: "document2.txt".to_string(),
-                lenght: 150,
+                length: 150,
             },
         ];
 
         let doc_collection = Documents {
             docs: documents.clone(),
+            avg_len: 125.0,
         };
 
         assert_eq!(doc_collection.get_num_documents(), documents.len() as u32);
 
         for (i, d) in documents.iter().enumerate() {
             assert_eq!(doc_collection.get_doc_path(i as u32), d.path);
-            assert_eq!(doc_collection.get_doc_len(i as u32), d.lenght);
+            assert_eq!(doc_collection.get_doc_len(i as u32), d.length);
         }
     }
 }
diff --git a/search/src/query/document_selector.rs b/search/src/engine/heap.rs
similarity index 56%
rename from search/src/query/document_selector.rs
rename to search/src/engine/heap.rs
index d5df92b..a01acc2 100644
--- a/search/src/query/document_selector.rs
+++ b/search/src/engine/heap.rs
@@ -3,12 +3,12 @@ use std::{cmp::Ordering, collections::BinaryHeap};
 #[derive(Debug)]
 pub struct Entry {
     pub id: u32,
-    pub score: f64,
+    pub priority: f64,
 }
 
 impl PartialEq for Entry {
     fn eq(&self, other: &Self) -> bool {
-        self.score == other.score
+        self.priority == other.priority
     }
 }
 
@@ -23,35 +23,41 @@ impl PartialOrd for Entry {
 impl Ord for Entry {
     fn cmp(&self, other: &Self) -> Ordering {
         other
-            .score
-            .partial_cmp(&self.score)
+            .priority
+            .partial_cmp(&self.priority)
             .unwrap_or(Ordering::Equal)
     }
 }
 
-pub struct DocumentSelector {
+pub struct FixedMaxHeap {
     heap: BinaryHeap<Entry>,
     capacity: usize,
 }
 
-impl DocumentSelector {
-    pub fn new(capacity: usize) -> DocumentSelector {
-        DocumentSelector {
+impl FixedMaxHeap {
+    pub fn new(capacity: usize) -> FixedMaxHeap {
+        FixedMaxHeap {
             heap: BinaryHeap::new(),
             capacity,
         }
     }
 
     pub fn push(&mut self, id: u32, score: f64) {
-        self.heap.push(Entry { id, score });
+        self.heap.push(Entry {
+            id,
+            priority: score,
+        });
 
         if self.heap.len() > self.capacity {
             self.heap.pop();
         }
     }
 
-    pub fn get_sorted_entries(&mut self) -> Vec<Entry> {
-        let mut res: Vec<Entry> = (0..self.capacity).filter_map(|_| self.heap.pop()).collect();
+    pub fn get_sorted_id_priority_pairs(&mut self) -> Vec<(u32, f64)> {
+        let mut res: Vec<(u32, f64)> = (0..self.capacity)
+            .filter_map(|_| self.heap.pop().map(|e| (e.id, e.priority)))
+            .collect();
+
         res.reverse();
         res
     }
 }
@@ -63,7 +69,7 @@ mod test {
 
     #[test]
     fn test_top_k() {
-        let mut selector = DocumentSelector::new(2);
+        let mut selector = FixedMaxHeap::new(2);
 
         selector.push(2, 0.4);
         selector.push(3, 0.3);
         selector.push(1, 0.5);
         selector.push(4, 0.2);
 
         assert_eq!(
-            selector
-                .get_sorted_entries()
-                .iter()
-                .map(|e| e.id)
-                .collect::<Vec<u32>>(),
-            [1, 2]
+            selector.get_sorted_id_priority_pairs(),
+            [(1, 0.5), (2, 0.4)]
         );
     }
 
     #[test]
     fn test_top_less_than_k() {
-        let mut selector = DocumentSelector::new(3);
+        let mut selector = FixedMaxHeap::new(3);
 
         selector.push(1, 0.5);
         selector.push(2, 0.4);
 
         assert_eq!(
-            selector
-                .get_sorted_entries()
-                .iter()
-                .map(|e| e.id)
-                .collect::<Vec<u32>>(),
-            [1, 2]
+            selector.get_sorted_id_priority_pairs(),
+            [(1, 0.5), (2, 0.4)]
         );
     }
 }
diff --git a/search/src/engine/mod.rs b/search/src/engine/mod.rs
new file mode 100644
index 0000000..c394adf
--- /dev/null
+++ b/search/src/engine/mod.rs
@@ -0,0 +1,213 @@
+mod builder;
+mod documents;
+mod heap;
+mod postings;
+mod preprocessor;
+mod utils;
+mod vocabulary;
+
+use self::documents::{Document, Documents};
+use self::heap::FixedMaxHeap;
+use self::postings::{PostingList, Postings};
+use self::preprocessor::Preprocessor;
+use self::vocabulary::Vocabulary;
+use std::cmp::min;
+use std::collections::{BTreeMap, HashMap};
+use std::time::Instant;
+
+pub const POSTINGS_EXTENSION: &str = ".postings";
OFFSETS_EXTENSION: &str = ".offsets"; +pub const DOCUMENTS_EXTENSION: &str = ".docs"; +pub const VOCABULARY_ALPHA_EXTENSION: &str = ".alphas"; + +const WINDOW_SCORE_MULTIPLIER: f64 = 0.5; +const BM25_SCORE_MULTIPLIER: f64 = 1.0; + +const BM25_KL: f64 = 1.2; +const BM25_B: f64 = 0.75; + +pub struct Engine { + vocabulary: Vocabulary, + postings: Postings, + documents: Documents, + preprocessor: Preprocessor, +} + +pub struct InMemory { + term_index_map: BTreeMap, + postings: Vec, + documents: Vec, +} + +pub struct QueryResult { + pub tokens: Vec, + pub documents: Vec, + pub time_ms: u128, +} + +pub struct DocumentResult { + pub id: u32, + pub path: String, + pub score: f64, +} + +#[derive(Default)] +struct DocumentScore { + tf_idf: f64, + term_positions: HashMap>, +} + +impl Engine { + pub fn build_engine(input_path: &str, output_path: &str) { + builder::build_engine(input_path, output_path, &Preprocessor::new()); + } + + pub fn load_index(input_path: &str) -> Engine { + Engine { + vocabulary: Vocabulary::load_vocabulary(input_path), + postings: Postings::load_postings_reader(input_path), + documents: Documents::load_documents(input_path), + preprocessor: Preprocessor::new(), + } + } + + pub fn query(&mut self, query: &str, num_results: usize) -> QueryResult { + let start_time = Instant::now(); + + let tokens: Vec = self + .preprocessor + .tokenize_and_stem(query) + .iter() + .filter_map(|t| self.vocabulary.spellcheck_term(t)) + .collect(); + + let mut scores: HashMap = HashMap::new(); + + let n = self.documents.get_num_documents() as f64; + let avgdl = self.documents.get_avg_doc_len(); + + for (id, token) in tokens.iter().enumerate() { + if let Some(postings) = self.get_term_postings(token) { + // compute idf where n is the number of documents and + // nq the number of documents containing query term + + let nq = postings.collection_frequency as f64; + let idf = ((n - nq + 0.5) / (nq + 0.5) + 1.0).ln(); + + for doc_posting in &postings.documents { + // compute B25 score, where fq is the frequency of term in this documents + // dl is the document len, and avgdl is the average document len accross the collection + + let fq = doc_posting.document_frequency as f64; + let dl = self.documents.get_doc_len(doc_posting.document_id) as f64; + + let bm_score = idf * (fq * (BM25_KL + 1.0)) + / (fq + BM25_KL * (1.0 - BM25_B + BM25_B * (dl / avgdl))); + + let doc_score = scores.entry(doc_posting.document_id).or_default(); + doc_score.tf_idf += bm_score; + let positions = doc_score.term_positions.entry(id as u32).or_default(); + + doc_posting + .positions + .iter() + .for_each(|p| positions.push(*p)); + } + } + } + + let mut selector = FixedMaxHeap::new(num_results); + let num_tokens = tokens.len(); + for (id, score) in &mut scores { + score.tf_idf /= self.documents.get_doc_len(*id) as f64; + selector.push(*id, Self::compute_score(score, num_tokens)); + } + + let documents = selector + .get_sorted_id_priority_pairs() + .iter() + .map(|(id, score)| DocumentResult { + id: *id, + score: *score, + path: self.documents.get_doc_path(*id), + }) + .collect(); + + let time_ms = start_time.elapsed().as_millis(); + + QueryResult { + tokens, + documents, + time_ms, + } + } + + fn get_term_postings(&mut self, term: &str) -> Option { + self.vocabulary + .get_term_index(term) + .map(|i| self.postings.load_postings_list(i)) + } + + fn compute_score(document_score: &DocumentScore, num_tokens: usize) -> f64 { + let mut window = u32::MAX; + + let mut arr: Vec<(u32, u32)> = document_score + .term_positions + .iter() + 
+            .flat_map(|(id, positions)| positions.iter().map(|p| (*p, *id)))
+            .collect();
+
+        arr.sort_unstable();
+
+        let mut j = 0;
+        let mut seen: HashMap<u32, u32> = HashMap::new();
+        for (pos, id) in arr.iter().copied() {
+            seen.entry(id).and_modify(|c| *c += 1).or_insert(1);
+
+            while seen.len() == num_tokens && j < arr.len() {
+                let (j_pos, j_id) = arr[j];
+                window = min(window, pos - j_pos + 1);
+
+                seen.entry(j_id).and_modify(|c| *c -= 1);
+                if *seen.get(&j_id).unwrap() == 0 {
+                    seen.remove(&j_id);
+                }
+
+                j += 1;
+            }
+        }
+
+        WINDOW_SCORE_MULTIPLIER * (num_tokens as f64 / window as f64)
+            + BM25_SCORE_MULTIPLIER * document_score.tf_idf
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use super::*;
+    use crate::test_utils::utils::create_temporary_dir_path;
+
+    #[test]
+    fn test_build() {
+        let index_path = &create_temporary_dir_path();
+
+        Engine::build_engine("test_data/docs", index_path);
+
+        let mut idx = Engine::load_index(index_path);
+
+        for ele in ["hello", "man", "world"] {
+            assert!(idx.vocabulary.get_term_index(ele).is_some());
+        }
+
+        let mut query: Vec<String> = idx
+            .query("hello", 10)
+            .documents
+            .iter()
+            .map(|d| d.path.clone())
+            .collect();
+
+        query.sort();
+
+        assert_eq!(query, ["test_data/docs/1.txt", "test_data/docs/2.txt"]);
+    }
+}
diff --git a/search/src/index/postings.rs b/search/src/engine/postings.rs
similarity index 100%
rename from search/src/index/postings.rs
rename to search/src/engine/postings.rs
diff --git a/search/src/index/preprocessor.rs b/search/src/engine/preprocessor.rs
similarity index 100%
rename from search/src/index/preprocessor.rs
rename to search/src/engine/preprocessor.rs
diff --git a/search/src/index/utils.rs b/search/src/engine/utils.rs
similarity index 100%
rename from search/src/index/utils.rs
rename to search/src/engine/utils.rs
diff --git a/search/src/index/vocabulary.rs b/search/src/engine/vocabulary.rs
similarity index 98%
rename from search/src/index/vocabulary.rs
rename to search/src/engine/vocabulary.rs
index 400d635..b1accbf 100644
--- a/search/src/index/vocabulary.rs
+++ b/search/src/engine/vocabulary.rs
@@ -163,7 +163,7 @@ mod tests {
     use std::collections::BTreeMap;
 
-    use crate::{index::postings::PostingList, test_utils::utils::create_temporary_file_path};
+    use crate::{engine::postings::PostingList, test_utils::utils::create_temporary_file_path};
 
     use super::*;
 
diff --git a/search/src/index/mod.rs b/search/src/index/mod.rs
deleted file mode 100644
index 25fbc48..0000000
--- a/search/src/index/mod.rs
+++ /dev/null
@@ -1,107 +0,0 @@
-mod builder;
-mod documents;
-mod postings;
-mod preprocessor;
-mod utils;
-mod vocabulary;
-
-use self::documents::{Document, Documents};
-use self::postings::{PostingList, Postings};
-use self::preprocessor::Preprocessor;
-use self::vocabulary::Vocabulary;
-use std::collections::BTreeMap;
-
-pub const POSTINGS_EXTENSION: &str = ".postings";
-pub const OFFSETS_EXTENSION: &str = ".offsets";
-pub const DOCUMENTS_EXTENSION: &str = ".docs";
-pub const VOCABULARY_ALPHA_EXTENSION: &str = ".alphas";
-
-pub struct Index {
-    vocabulary: Vocabulary,
-    postings: Postings,
-    documents: Documents,
-    preprocessor: Preprocessor,
-}
-
-pub struct InMemory {
-    term_index_map: BTreeMap<String, usize>,
-    postings: Vec<PostingList>,
-    documents: Vec<Document>,
-}
-
-impl Index {
-    pub fn build_index(input_path: &str, output_path: &str) {
-        builder::build_index(input_path, output_path, &Preprocessor::new());
-    }
-
-    pub fn load_index(input_path: &str) -> Index {
-        Index {
-            vocabulary: Vocabulary::load_vocabulary(input_path),
-            postings: Postings::load_postings_reader(input_path),
-            documents: Documents::load_documents(input_path),
-            preprocessor: Preprocessor::new(),
-        }
-    }
-
-    pub fn get_term_postings(&mut self, term: &str) -> Option<PostingList> {
-        self.vocabulary
-            .get_term_index(term)
-            .map(|i| self.postings.load_postings_list(i))
-    }
-
-    pub fn get_query_tokens(&self, query: &str) -> Vec<String> {
-        self.preprocessor.tokenize_and_stem(query)
-    }
-
-    pub fn get_num_documents(&self) -> u32 {
-        self.documents.get_num_documents()
-    }
-
-    pub fn get_document_len(&self, doc_id: u32) -> u32 {
-        self.documents.get_doc_len(doc_id)
-    }
-
-    pub fn get_document_path(&self, doc_id: u32) -> String {
-        self.documents.get_doc_path(doc_id)
-    }
-
-    pub fn spellcheck_term(&self, term: &str) -> Option<String> {
-        self.vocabulary.spellcheck_term(term)
-    }
-}
-
-#[cfg(test)]
-mod test {
-    use super::*;
-    use crate::test_utils::utils::create_temporary_dir_path;
-
-    #[test]
-    fn test_build() {
-        let index_path = &create_temporary_dir_path();
-
-        Index::build_index("test_data/docs", index_path);
-
-        let mut idx = Index::load_index(index_path);
-
-        for ele in ["hello", "man", "world"] {
-            assert!(idx.vocabulary.get_term_index(ele).is_some());
-        }
-
-        let pl = idx.get_term_postings("hello").unwrap();
-
-        let mut hello_docs = pl
-            .documents
-            .iter()
-            .map(|d| idx.get_document_path(d.document_id))
-            .collect::<Vec<String>>();
-
-        hello_docs.sort();
-
-        assert_eq!(hello_docs, ["test_data/docs/1.txt", "test_data/docs/2.txt"]);
-
-        assert_eq!(pl.collection_frequency, 2);
-
-        let pl = idx.get_term_postings("world").unwrap();
-        assert_eq!(pl.documents[0].positions, [1]);
-    }
-}
diff --git a/search/src/lib.rs b/search/src/lib.rs
index fee1643..c88bffb 100644
--- a/search/src/lib.rs
+++ b/search/src/lib.rs
@@ -1,4 +1,3 @@
 pub mod disk;
-pub mod index;
-pub mod query;
+pub mod engine;
 mod test_utils;
diff --git a/search/src/main.rs b/search/src/main.rs
index e9c2d32..2624685 100644
--- a/search/src/main.rs
+++ b/search/src/main.rs
@@ -1,6 +1,5 @@
 use indicatif::HumanDuration;
-use search::index::Index;
-use search::query::{Processor, Result};
+use search::engine::{Engine, QueryResult};
 use std::cmp::min;
 use std::env;
 use std::io::{self, Write};
@@ -10,7 +9,7 @@ use std::time::{Duration, Instant};
 const NUM_TOP_RESULTS: usize = 10;
 const NUM_RESULTS: usize = 1_000_000;
 
-fn print_results(result: &Result) {
+fn print_results(result: &QueryResult) {
     println!("Search tokens: {:?}", result.tokens);
 
     if result.documents.is_empty() {
@@ -85,7 +84,7 @@
 
     let start_time = Instant::now();
 
-    Index::build_index(&docs_path, &index_path);
+    Engine::build_engine(&docs_path, &index_path);
 
     let elapsed_time = start_time.elapsed();
     println!(
         "Index built in {}.\n\nLoad options:\n- CLI: cargo run --release --bin search {} load",
@@ -96,7 +95,7 @@
         exit(0);
     }
 
-    let mut q = Processor::build_query_processor(&index_path);
+    let mut e = Engine::load_index(&index_path);
 
     println!(
         "Loaded search engine for directory: [{base_path}]\n\nWrite a query and press enter.\n"
@@ -105,7 +104,7 @@
     loop {
         let query = read_line("> ");
 
-        let result = q.query(&query, NUM_RESULTS);
+        let result = e.query(&query, NUM_RESULTS);
 
         print_results(&result);
     }
diff --git a/search/src/query/mod.rs b/search/src/query/mod.rs
deleted file mode 100644
index 0cf925a..0000000
--- a/search/src/query/mod.rs
+++ /dev/null
@@ -1,148 +0,0 @@
-use std::{cmp::min, collections::HashMap, time::Instant};
-
-use crate::index::Index;
-
-use self::document_selector::DocumentSelector;
-
-mod document_selector;
-
-const WINDOW_MULTIPLIER: f64 = 10.0;
-
-pub struct Processor {
-    index: Index,
-    num_documents: u32,
-}
-
-pub struct Result {
-    pub tokens: Vec<String>,
-    pub documents: Vec<DocumentResult>,
-    pub time_ms: u128,
-}
-
-pub struct DocumentResult {
-    pub id: u32,
-    pub path: String,
-    pub score: f64,
-}
-
-#[derive(Default)]
-struct DocumentScore {
-    tf_idf: f64,
-    term_positions: HashMap<u32, Vec<u32>>,
-}
-
-impl Processor {
-    pub fn build_query_processor(index_input_path: &str) -> Processor {
-        let index = Index::load_index(index_input_path);
-        let num_documents = index.get_num_documents();
-
-        Processor {
-            index,
-            num_documents,
-        }
-    }
-
-    pub fn query(&mut self, query: &str, num_results: usize) -> Result {
-        let start_time = Instant::now();
-
-        // spellcheck phase
-        let tokens: Vec<String> = self
-            .index
-            .get_query_tokens(query)
-            .iter()
-            .filter_map(|t| self.index.spellcheck_term(t))
-            .collect();
-
-        // retrieve documents
-        let documents = self
-            .get_sorted_document_entries(&tokens.clone(), num_results)
-            .iter()
-            .map(|e| DocumentResult {
-                id: e.id,
-                score: e.score,
-                path: self.index.get_document_path(e.id),
-            })
-            .collect();
-
-        let time_ms = start_time.elapsed().as_millis();
-
-        Result {
-            tokens,
-            documents,
-            time_ms,
-        }
-    }
-
-    fn get_sorted_document_entries(
-        &mut self,
-        tokens: &[String],
-        num_results: usize,
-    ) -> Vec<Entry> {
-        let mut scores: HashMap<u32, DocumentScore> = HashMap::new();
-
-        for (id, token) in tokens.iter().enumerate() {
-            if let Some(postings) = self.index.get_term_postings(token) {
-                let idf = (self.num_documents as f64 / postings.collection_frequency as f64).log2();
-
-                // for each term-doc pair, increment the documetn tf-idf score
-                // and record token positions for window computation
-                for doc_posting in &postings.documents {
-                    let td_idf_score = doc_posting.document_frequency as f64 * idf;
-
-                    let doc_score = scores.entry(doc_posting.document_id).or_default();
-
-                    doc_score.tf_idf += td_idf_score;
-                    let positions = doc_score.term_positions.entry(id as u32).or_default();
-
-                    doc_posting
-                        .positions
-                        .iter()
-                        .for_each(|p| positions.push(*p));
-                }
-            }
-        }
-
-        let mut selector = DocumentSelector::new(num_results);
-        let num_tokens = tokens.len();
-        for (id, score) in &mut scores {
-            // tf-idf score must be divided by the document len
-            score.tf_idf /= self.index.get_document_len(*id) as f64;
-            selector.push(*id, Processor::compute_score(score, num_tokens));
-        }
-
-        selector.get_sorted_entries()
-    }
-
-    // score takes into consideration the window size and td-idf scoring
-    fn compute_score(document_score: &DocumentScore, num_tokens: usize) -> f64 {
-        let mut window = u32::MAX;
-
-        let mut arr: Vec<(u32, u32)> = document_score
-            .term_positions
-            .iter()
-            .flat_map(|(id, positions)| positions.iter().map(|p| (*p, *id)))
-            .collect();
-
-        arr.sort_unstable();
-
-        let mut j = 0;
-        let mut seen: HashMap<u32, u32> = HashMap::new();
-        for (pos, id) in arr.iter().copied() {
-            seen.entry(id).and_modify(|c| *c += 1).or_insert(1);
-
-            while seen.len() == num_tokens && j < arr.len() {
-                let (j_pos, j_id) = arr[j];
-                window = min(window, pos - j_pos + 1);
-
-                seen.entry(j_id).and_modify(|c| *c -= 1);
-                if *seen.get(&j_id).unwrap() == 0 {
-                    seen.remove(&j_id);
-                }
-
-                j += 1;
-            }
-        }
-
-        WINDOW_MULTIPLIER * (num_tokens as f64 / window as f64) + document_score.tf_idf
-    }
-}
diff --git a/server/Cargo.toml b/server/Cargo.toml
index 03017f3..6da40af 100644
--- a/server/Cargo.toml
+++ b/server/Cargo.toml
@@ -13,3 +13,4 @@ log = "0.4.20"
 search = { path = "../search" }
 serde = { version = "1.0.195", features = ["derive"] }
 tokio = { version = "1.35.1", features = ["macros", "rt-multi-thread"] }
version = "1.35.1", features = ["macros", "rt-multi-thread"] } +lru = "0.12.1" diff --git a/server/src/main.rs b/server/src/main.rs index d12f3a2..0827158 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -8,20 +8,22 @@ use axum::{ Form, Router, }; use log::info; -use search::query::Processor; +use lru::LruCache; +use search::engine::Engine; use serde::{Deserialize, Serialize}; use std::{ env, fs::read_to_string, - mem::replace, + num::NonZeroUsize, sync::{Arc, Mutex}, }; +const CACHE_SIZE: usize = 10; + struct AppState { index_path: String, - query_processor: Mutex, - cached_query: Mutex, - cached_result: Mutex>, + engine: Mutex, + query_cache: Mutex>, } #[tokio::main] @@ -42,9 +44,8 @@ async fn main() { let state = Arc::new(AppState { index_path: base_path.clone(), - query_processor: Mutex::new(Processor::build_query_processor(&index_path)), - cached_query: Mutex::new(String::new()), - cached_result: Mutex::new(None), + engine: Mutex::new(Engine::load_index(&index_path)), + query_cache: Mutex::new(LruCache::new(NonZeroUsize::new(CACHE_SIZE).unwrap())), }); let app = Router::new() @@ -119,17 +120,16 @@ async fn post_query( ) -> impl IntoResponse { info!("Query request: {}", payload.query); - let mut cq = state.cached_query.lock().unwrap(); - let mut cqr = state.cached_result.lock().unwrap(); + let mut query_cache = state.query_cache.lock().unwrap(); - if *cq == payload.query { + if let Some(cached_result) = query_cache.get(&payload.query) { info!("Cache hit for query: {}", payload.query); - return HtmlTemplate(cqr.clone().unwrap()); + return HtmlTemplate(cached_result.clone()); } - let mut q = state.query_processor.lock().unwrap(); + let mut engine = state.engine.lock().unwrap(); - let query_result = q.query(&payload.query, 100); + let query_result = engine.query(&payload.query, 100); let documents = query_result .documents @@ -148,9 +148,8 @@ async fn post_query( time_ms: query_result.time_ms, }; - info!("Replacing cache with query: {}", payload.query); - let _ = replace(&mut *cq, payload.query); - let _ = replace(&mut *cqr, Some(response.clone())); + info!("Caching query: {}", payload.query); + query_cache.put(payload.query.clone(), response.clone()); HtmlTemplate(response) } diff --git a/server/templates/index.html b/server/templates/index.html index 15036e1..f71bd64 100644 --- a/server/templates/index.html +++ b/server/templates/index.html @@ -65,7 +65,7 @@ -
+
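
Reviewer note: the BM25 step in `Engine::query` is self-contained enough to check by hand. Below is a minimal standalone sketch that mirrors the patch's formula and constants (`BM25_KL` is the standard k1 = 1.2, `BM25_B` is b = 0.75); the free function and the sample figures are illustrative assumptions, not code from this patch.

```rust
// Standalone sketch of the BM25 term-document score used in Engine::query.
// n: collection size; nq: documents containing the term; fq: term frequency
// in the scored document; dl: its length; avgdl: average document length.
fn bm25_term_score(n: f64, nq: f64, fq: f64, dl: f64, avgdl: f64) -> f64 {
    const K1: f64 = 1.2; // the patch names this BM25_KL
    const B: f64 = 0.75; // BM25_B

    // idf grows as the term gets rarer; the +1.0 keeps it positive
    let idf = ((n - nq + 0.5) / (nq + 0.5) + 1.0).ln();

    // term-frequency saturation with document-length normalization
    idf * (fq * (K1 + 1.0)) / (fq + K1 * (1.0 - B + B * (dl / avgdl)))
}

fn main() {
    // a term occurring 3 times in a 100-token document, with 50 of 1000
    // documents containing it and an average length of 120 tokens
    println!("{:.4}", bm25_term_score(1000.0, 50.0, 3.0, 100.0, 120.0));
}
```

One design point worth flagging: `query` additionally divides each document's accumulated BM25 sum by the document length before ranking, on top of the length normalization BM25 already applies through b.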
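The proximity bonus in `compute_score` is a standard minimal-window scan over the merged term positions. Here is a self-contained sketch of the same two-pointer logic, written as a hypothetical free function with a toy assertion:

```rust
use std::cmp::min;
use std::collections::HashMap;

// Smallest span of positions covering every distinct query token at least
// once, given every (position, token_id) occurrence in one document.
fn min_window(mut arr: Vec<(u32, u32)>, num_tokens: usize) -> u32 {
    arr.sort_unstable();

    let mut window = u32::MAX;
    let mut seen: HashMap<u32, u32> = HashMap::new();
    let mut j = 0;

    for (pos, id) in arr.iter().copied() {
        *seen.entry(id).or_insert(0) += 1;

        // once the window covers every token, shrink it from the left
        // for as long as it still covers them all
        while seen.len() == num_tokens && j < arr.len() {
            let (j_pos, j_id) = arr[j];
            window = min(window, pos - j_pos + 1);

            let count = seen.get_mut(&j_id).unwrap();
            *count -= 1;
            if *count == 0 {
                seen.remove(&j_id);
            }
            j += 1;
        }
    }

    window
}

fn main() {
    // token 0 at positions 2 and 9, token 1 at position 5:
    // the best window spans positions 2..=5, i.e. size 4
    assert_eq!(min_window(vec![(2, 0), (9, 0), (5, 1)], 2), 4);
    println!("ok");
}
```

When the window size equals num_tokens the terms are adjacent, num_tokens / window reaches 1.0, and the document earns the full WINDOW_SCORE_MULTIPLIER bonus.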
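On the `DocumentSelector` to `FixedMaxHeap` rename: the reversed comparison in `Ord for Entry` turns std's max-heap into a min-heap, so once the heap exceeds `capacity`, `pop` evicts the lowest-priority entry and the k best survive. The same pattern can also be expressed with `std::cmp::Reverse` instead of a hand-written `Ord`; a sketch for comparison, where the `OrdF64` wrapper is only needed because `f64` is not `Ord`:

```rust
use std::cmp::{Ordering, Reverse};
use std::collections::BinaryHeap;

// f64 is not Ord, so wrap it; ties compare as equal, as in the patch.
#[derive(PartialEq)]
struct OrdF64(f64);

impl Eq for OrdF64 {}

impl PartialOrd for OrdF64 {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for OrdF64 {
    fn cmp(&self, other: &Self) -> Ordering {
        self.0.partial_cmp(&other.0).unwrap_or(Ordering::Equal)
    }
}

fn main() {
    let capacity = 2;
    let mut heap: BinaryHeap<Reverse<(OrdF64, u32)>> = BinaryHeap::new();

    for (id, score) in [(2, 0.4), (3, 0.3), (1, 0.5), (4, 0.2)] {
        heap.push(Reverse((OrdF64(score), id)));
        if heap.len() > capacity {
            heap.pop(); // evicts the current minimum, keeping the top k
        }
    }

    let mut top: Vec<(u32, f64)> = heap
        .drain()
        .map(|Reverse((score, id))| (id, score.0))
        .collect();
    top.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());

    assert_eq!(top, [(1, 0.5), (2, 0.4)]);
}
```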
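The documents file keeps its front-coded path layout: each entry stores the length of the prefix shared with the previous path (gamma-coded) plus the remaining suffix and the document length (v-byte). A sketch of just the prefix-sharing idea, with plain tuples standing in for the `BitsWriter`/`BitsReader` layer:

```rust
// Front coding: store (shared_prefix_len, suffix) per path, as in
// Documents::write_documents; the bit-level encoding is omitted here.
fn front_code(paths: &[String]) -> Vec<(usize, String)> {
    let mut prev: &str = "";
    paths
        .iter()
        .map(|p| {
            // count leading characters shared with the previous path
            let p_len = p
                .chars()
                .zip(prev.chars())
                .take_while(|(a, b)| a == b)
                .count();
            let suffix: String = p.chars().skip(p_len).collect();
            prev = p.as_str();
            (p_len, suffix)
        })
        .collect()
}

fn main() {
    let paths = vec!["docs/a.txt".to_string(), "docs/ab.txt".to_string()];
    // "docs/a" (6 chars) is shared, so the second entry stores (6, "b.txt")
    assert_eq!(
        front_code(&paths),
        [(0, "docs/a.txt".to_string()), (6, "b.txt".to_string())]
    );
    println!("ok");
}
```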
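Finally, the server's switch from a single cached query to `lru::LruCache` means any of the last `CACHE_SIZE` distinct queries can hit the cache, not just the most recent one. A minimal usage sketch of the crate API relied on above, with `String` values standing in for the rendered template type:

```rust
use lru::LruCache;
use std::num::NonZeroUsize;

fn main() {
    // capacity mirrors the patch's CACHE_SIZE; eviction is least-recently-used
    let mut cache: LruCache<String, String> =
        LruCache::new(NonZeroUsize::new(10).unwrap());

    cache.put("rust bm25".to_string(), "rendered results page".to_string());

    // get() also refreshes the entry's recency, unlike peek()
    if let Some(hit) = cache.get("rust bm25") {
        println!("cache hit: {hit}");
    }
}
```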