Skip to content

Commit

Permalink
Everything working but printing on the lua side
Browse files Browse the repository at this point in the history
  • Loading branch information
matthewmturner committed Jan 6, 2025
1 parent fd34d1c commit 14406e8
Show file tree
Hide file tree
Showing 8 changed files with 107 additions and 40 deletions.
Binary file modified artifacts/libffi.dylib
Binary file not shown.
8 changes: 5 additions & 3 deletions benches/benches/build_index.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
use std::ffi::c_char;

use criterion::{criterion_group, criterion_main, Criterion};

extern "C" fn cb(_progress: f64) {}
extern "C" fn cb(_msg: *const c_char) {}

fn build_index() {
let mut index = rfsee_tf_idf::TfIdf::default();
index.par_load_rfcs(cb, cb).unwrap();
index.finish();
index.par_load_rfcs(cb).unwrap();
index.finish(cb);
let path = std::path::PathBuf::from("/tmp/bench_index.json");
index.save(&path)
}
Expand Down
25 changes: 17 additions & 8 deletions crates/cli/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
use std::{fs::File, path::PathBuf, time::Instant};
use std::{
ffi::{c_char, CStr},
fs::File,
path::PathBuf,
time::Instant,
};

use clap::{Parser, Subcommand};
use rfsee_tf_idf::{
Expand Down Expand Up @@ -27,12 +32,15 @@ enum Command {
},
}

extern "C" fn fetch_progress_cb(progress: f64) {
println!("Fetching RFCs progress: {progress:.2}%")
}
extern "C" fn print_c_char(ptr: *const c_char) {
if ptr.is_null() {
return;
}

extern "C" fn parse_progress_cb(progress: f64) {
println!("Parsing RFCs progress: {progress:.2}%")
let msg = unsafe { CStr::from_ptr(ptr) };
if let Ok(msg) = msg.to_str() {
println!("{msg}")
}
}

fn handle_command(args: Args) -> RFSeeResult<()> {
Expand All @@ -42,13 +50,14 @@ fn handle_command(args: Args) -> RFSeeResult<()> {
println!("Indexing RFCs");
let start = Instant::now();
let mut index = TfIdf::default();
index.par_load_rfcs(fetch_progress_cb, parse_progress_cb)?;
index.par_load_rfcs(print_c_char)?;
println!("Loading RFCs took {:?}", start.elapsed());
let building_index_start = Instant::now();
index.finish();
index.finish(print_c_char);
println!("Building index took {:?}", building_index_start.elapsed());
let saving_start = Instant::now();
let index_path = get_index_path(path)?;
println!("Saving index");
index.save(&index_path);
println!("Saving index took {:?}", saving_start.elapsed());
}
Expand Down
11 changes: 3 additions & 8 deletions crates/ffi/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,11 @@ struct RfcSearchResultsContainer {
}

#[no_mangle]
pub extern "C" fn build_index(
fetch_progress_cb: extern "C" fn(progress: f64),
parse_progress_cb: extern "C" fn(progress: f64),
) {
pub extern "C" fn build_index(progress_cb: extern "C" fn(msg: *const c_char)) {
let path = rfsee_tf_idf::get_index_path(None).unwrap();
let mut index = rfsee_tf_idf::TfIdf::default();
index
.par_load_rfcs(fetch_progress_cb, parse_progress_cb)
.unwrap();
index.finish();
index.par_load_rfcs(progress_cb).unwrap();
index.finish(progress_cb);
index.save(&path);
}

Expand Down
45 changes: 34 additions & 11 deletions crates/tf_idf/src/index.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::{
collections::HashMap,
ffi::{c_char, CString},
path::{Path, PathBuf},
sync::{Arc, Mutex},
time::Duration,
Expand Down Expand Up @@ -142,8 +143,7 @@ impl TfIdf {
/// Load the RFCs in parallel using a threadpool
pub fn par_load_rfcs(
&mut self,
fetch_progress_cb: extern "C" fn(progress: f64),
parse_progress_cb: extern "C" fn(progress: f64),
progress_cb: extern "C" fn(progress: *const c_char),
) -> RFSeeResult<()> {
let pool = threadpool::ThreadPool::new(12);
let raw_rfc_index = fetch_rfc_index()?;
Expand All @@ -164,11 +164,13 @@ impl TfIdf {
if let Ok(r) = fetch_rfc(&string) {
let mut guard = parsed_rfcs.lock().unwrap();
guard.push(r);
let processed = guard.len();
if processed % 100 == 0 {
let progress = (processed as f64 / rfcs_count as f64) * 100_f64;
fetch_progress_cb(progress)
}
// let processed = guard.len();
// if processed % 100 == 0 {
// let progress = (processed as f64 / rfcs_count as f64) * 100_f64;
// if let Ok(msg) = CString::new(format!("Fetch progress: {progress:0.0}%")) {
// progress_cb(msg.as_ptr())
// }
// }
};
let mut guard = remaining.lock().unwrap();
*guard -= 1;
Expand All @@ -179,10 +181,14 @@ impl TfIdf {
while !finished {
let remaining = remaining.clone();
let guard = remaining.lock().unwrap();
if let Ok(msg) = CString::new(format!("{} remaining RFCs to fetch", *guard)) {
progress_cb(msg.into_raw())
}
if *guard == 0 {
finished = true
} else {
drop(guard);
// Don't want to go crazy locking the Mutex, so we only check every 5 seconds
std::thread::sleep(Duration::from_secs(5));
}
}
Expand All @@ -194,7 +200,11 @@ impl TfIdf {
self.add_rfc_entry(rfc);
if i % 100 == 0 {
let progress = (i as f64 / rfcs_count as f64) * 100_f64;
parse_progress_cb(progress)
// if let Ok(msg) =
// CString::new(format!("Parse progress: {progress:0.0}%"))
// {
// progress_cb(msg.as_ptr())
// }
}
}
}
Expand Down Expand Up @@ -248,7 +258,10 @@ impl TfIdf {

/// Take all the processed documents and their term frequencies to compute the final term
/// scores
pub fn finish(&mut self) {
pub fn finish(&mut self, progress_cb: extern "C" fn(*const c_char)) {
if let Ok(msg) = CString::new("Collecting terms") {
progress_cb(msg.as_ptr())
}
// First, we collect all terms and the number of docs they appear in
let mut term_counts: HashMap<&String, usize> = HashMap::new();
for indexed_rfc in self.processed_rfcs.values() {
Expand All @@ -261,6 +274,9 @@ impl TfIdf {
}
}

if let Ok(msg) = CString::new("Computing inverse document frequencies") {
progress_cb(msg.as_ptr())
}
// Then we compute the inverse document frequency for each term
let total_docs = self.processed_rfcs.len();
for (term, docs_with_term) in term_counts {
Expand All @@ -269,6 +285,9 @@ impl TfIdf {
self.idfs.insert(term.clone(), scaled);
}

if let Ok(msg) = CString::new("Scoring documents") {
progress_cb(msg.as_ptr())
}
// Then we compute the score for each term in all documents
self.processed_rfcs.iter().for_each(|(_doc, rfc)| {
for (doc_term, freq) in &rfc.term_freqs {
Expand Down Expand Up @@ -357,8 +376,12 @@ pub fn search_index(search: String, index: Index) -> Vec<RfcSearchResult> {

#[cfg(test)]
mod tests {
use std::ffi::c_char;

use super::{parse_rfc_index, RfcEntry, TfIdf};

extern "C" fn dummy_cb(_msg: *const c_char) {}

#[test]
fn test_parse_index() {
let index_contents = std::fs::read_to_string("../../data/rfc_index.txt").unwrap();
Expand All @@ -377,7 +400,7 @@ mod tests {
url: "https://www.rfsee.com/1".to_string(),
};
tf_idf.add_rfc_entry(entry);
tf_idf.finish();
tf_idf.finish(dummy_cb);

assert_eq!(tf_idf.index.rfc_details.len(), 1);
assert_eq!(tf_idf.index.term_scores.len(), 2);
Expand All @@ -398,7 +421,7 @@ mod tests {
url: "https://www.rfsee.com/1".to_string(),
};
tf_idf.add_rfc_entry(entry);
tf_idf.finish();
tf_idf.finish(dummy_cb);

assert_eq!(tf_idf.index.rfc_details.len(), 1);
// This should be 1 once we update parsing to treat "Hello" and "hello" the same
Expand Down
2 changes: 1 addition & 1 deletion lua/rfsee/ffi.lua
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ local ffi = require("ffi")

ffi.cdef([[
typedef void (*progress_callback_t)(double progress);
void build_index(progress_callback_t fetch_cb, progress_callback_t parse_cb);
void build_index(progress_callback_t progress_cb);
struct RfcSearchResult {
const char* url;
Expand Down
50 changes: 42 additions & 8 deletions lua/rfsee/index.lua
Original file line number Diff line number Diff line change
Expand Up @@ -66,21 +66,55 @@ function M.refresh()
local buf, win = window.create_progress_window()
window.update_progress_window(buf, "Building RFC index")

local function fetch_progress_cb(pct)
local msg = string.format("Downloading RFCs progress: %.1f%%", pct)
-- local function progress_cb(ptr)
-- -- local msg = ffi.string(ptr)
-- -- -- local msg = string.format("Downloading RFCs progress: %.1f%%", pct)
-- -- window.update_progress_window(buf, msg)
--
-- local ok, err = xpcall(function()
-- print("Getting string")
-- io.stdout:flush()
-- local msg = ffi.string(ptr) -- May throw if `ptr` is invalid
-- print("Got string")
-- io.stdout:flush()
-- window.update_progress_window(buf, msg)
-- end, debug.traceback)
-- print("Ok: ", ok)
-- io.stdout:flush()
-- if not ok then
-- -- Log the error, but don't let it unwind into Rust
-- print("Error in progress_cb:", err)
-- vim.cmd("redraw")
-- end
-- end

local function real_progress_cb(ptr)
print("Inside real_progress_cb, about to ffi.string(ptr)")
io.stdout:flush()

local msg = ffi.string(ptr) -- If this fails, it won't kill the process if wrapped by pcall
print("Successfully got msg:", msg)
io.stdout:flush()

window.update_progress_window(buf, msg)
end

local function safe_progress_cb(ptr)
local ok, err = xpcall(function()
real_progress_cb(ptr)
end, debug.traceback)

local function parse_progress_cb(pct)
local msg = string.format("Parsing RFCs progress: %.1f%%", pct)
window.update_progress_window(buf, msg)
if not ok then
print("Lua callback error:", err)
io.stdout:flush()
-- DO NOT re-throw the error; otherwise it unwinds back into Rust
end
end

local fetch_progress_cb_c = ffi.cast("progress_callback_t", fetch_progress_cb)
local parse_progress_cb_c = ffi.cast("progress_callback_t", parse_progress_cb)
local progress_cb_c = ffi.cast("progress_callback_t", safe_progress_cb)

lib.build_index(fetch_progress_cb_c, parse_progress_cb_c)
-- M.progress_cb_c = ffi.cast("progress_callback_t", progress_cb)
lib.build_index(progress_cb_c)
local end_time = os.clock()
window.update_progress_window(buf, string.format("Built RFC index", end_time - start_time))
-- Brief pause before closing
Expand Down
6 changes: 5 additions & 1 deletion tests/generate-data/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
use std::ffi::c_char;

use rfsee_tf_idf::{RfcEntry, TfIdf};

extern "C" fn dummy_cb(_msg: *const c_char) {}

fn main() {
let mut tf_idf = TfIdf::default();
let rfc1 = RfcEntry {
Expand All @@ -18,7 +22,7 @@ fn main() {
tf_idf.add_rfc_entry(rfc1);
tf_idf.add_rfc_entry(rfc2);

tf_idf.finish();
tf_idf.finish(dummy_cb);
let path = rfsee_tf_idf::get_index_path(None).unwrap();
tf_idf.save(&path);
}

0 comments on commit 14406e8

Please sign in to comment.