Skip to content

Commit

Permalink
simple main
Browse files Browse the repository at this point in the history
  • Loading branch information
tomfran committed Dec 28, 2023
1 parent 0d4d6fa commit 1559704
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 2 deletions.
48 changes: 48 additions & 0 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
// use search::index::Index;
use search::query::QueryProcessor;
use std::io::{self, Write};

/// Number of results requested from the query processor for every query
/// entered at the prompt.
const NUM_RESULTS: usize = 10;

/// Prints the document ids of a search to stdout.
///
/// The output is a blank line, a `Search Results:` header, one tab-indented
/// line per document showing its 1-based position (right-aligned to width 3)
/// and its id, and finally a trailing blank line.
fn print_results(results: &[u32]) {
    println!("\nSearch Results:");
    // Pair every id with its 1-based position in the slice.
    results
        .iter()
        .zip(1usize..)
        .for_each(|(doc_id, rank)| println!("\t- {:3}. Doc ID: {}", rank, doc_id));
    println!();
}

/// Writes `prompt` to stdout and reads one line from stdin.
///
/// Returns `Ok(Some(line))` with leading and trailing whitespace removed, or
/// `Ok(None)` once stdin has reached end-of-file (Ctrl-D, or exhausted piped
/// input).
///
/// # Errors
///
/// Returns the underlying I/O error if flushing stdout or reading stdin fails.
fn read_line(prompt: &str) -> io::Result<Option<String>> {
    print!("{}", prompt);
    // `print!` does not emit a newline, so flush explicitly to make sure the
    // prompt is visible before blocking on input.
    io::stdout().flush()?;

    let mut input = String::new();
    // `read_line` signals end-of-file by returning `Ok(0)`. Without this check
    // EOF is indistinguishable from an empty line and the caller's loop would
    // spin forever.
    if io::stdin().read_line(&mut input)? == 0 {
        return Ok(None);
    }

    Ok(Some(input.trim().to_string()))
}

/// Runs an interactive search prompt over the index stored under
/// `data/wiki-data`.
///
/// Each line typed at the `> ` prompt is passed to the query processor and the
/// returned document ids are printed. The loop ends cleanly when stdin reaches
/// end-of-file; an I/O error on stdin/stdout is reported on stderr and the
/// process exits with status 1.
fn main() {
    let base_path = "data/wiki-data";
    let index_path = format!("{}/index/index", base_path);
    let tokenizer_path = format!("{}/tokenizer/bert-base-uncased", base_path);

    // Index construction is disabled here; re-enable these lines (and the
    // `Index` import) to build the index before querying.
    // let docs_path = base_path.to_string() + "/docs";
    // Index::build_index(&docs_path, &index_path, &tokenizer_path);

    println!(
        "Search engine for base path: [{}]\nWrite a query and press enter.\n",
        base_path
    );

    let mut q = QueryProcessor::build_query_processor(&index_path, &tokenizer_path);

    loop {
        let query = match read_line("> ") {
            Ok(Some(query)) => query,
            Ok(None) => {
                // End of input: finish the prompt line and leave the loop.
                println!();
                break;
            }
            Err(err) => {
                eprintln!("failed to read query: {}", err);
                std::process::exit(1);
            }
        };

        // Perform search
        let results = q.query(&query, NUM_RESULTS);

        // Display results
        print_results(&results);
    }
}
11 changes: 9 additions & 2 deletions src/query/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,14 @@ impl QueryProcessor {
}
}

pub fn query(&mut self, query: &str) -> Vec<u32> {
pub fn query(&mut self, query: &str, num_results: usize) -> Vec<u32> {
let mut scores: HashMap<u32, f32> = HashMap::new();

println!(
"\t### tokenized query: {:?}",
self.index.tokenize_and_stem_query(query)
);

for token in self.index.tokenize_and_stem_query(query) {
if let Some(postings) = self.index.get_term(&token) {
let idf = (self.num_documents as f32 / postings.collection_frequency as f32).log2();
Expand All @@ -38,10 +43,12 @@ impl QueryProcessor {
.and_modify(|s| *s += doc_score)
.or_insert(doc_score);
}
} else {
println!("\t### no postings for term: {}", token);
}
}

let mut selector = DocumentSelector::new(3);
let mut selector = DocumentSelector::new(num_results);
scores
.iter()
.for_each(|(id, score)| selector.push(*id, *score));
Expand Down

0 comments on commit 1559704

Please sign in to comment.