From e68f261a091f209c757a623f0e903360e5eca2a8 Mon Sep 17 00:00:00 2001
From: Alex Garcia
Date: Sun, 3 Nov 2024 22:50:42 -0800
Subject: [PATCH] initial pass

---
Review notes (commentary placed after the "---" line; not part of the commit):

Restoration of this patch text
  * Rebuilt from a copy whose newlines/indentation had been collapsed and
    whose generic arguments had been stripped (it read "Rc>>", "Option>",
    "define_table_function::(db", ...). Line breaks were recovered from the
    +/- markers and agree with every hunk header and with the diffstat.
  * Indentation follows rustfmt defaults; the SQL indentation is a best
    guess. Generic arguments were inferred from how the values are used.
    Confirm both against the repository before applying.
  * The author e-mail on the "From:" line was missing in that copy and has
    deliberately not been invented.

Follow-ups for the code in this commit
  * src/batch.rs, BatchCursor::filter: failures on caller-supplied data
    (missing argument, non-text value, invalid JSON, non-array JSON, element
    without a string "contents", failed infer_multiple call) all go through
    unwrap() and panic instead of returning Err like the unknown-client case
    a few lines above.
  * src/batch.rs, BatchCursor::filter: any client other than Client::Ollama
    reaches todo!().
  * src/batch.rs, BatchCursor::filter: embeddings.iter().zip(input) stops at
    the shorter side, so a response with fewer embeddings than inputs drops
    rows without an error.
  * src/batch.rs, BatchCursor::column: input1, input2, source and unknown
    column indexes reach todo!(); the Embedding arm does not check the
    value returned by api::result_blob.
  * src/batch.rs, BatchTable::best_index: has_input2 is assigned but never
    read.
  * src/clients.rs, OllamaClient::infer_multiple: the endpoint is the
    literal "http://localhost:11434/v1/embeddings".
  * src/clients.rs, OllamaClient::parse_multiple_response: builds an Err
    for a missing or non-array "data" key and then unwrap()s it; the
    per-item "embedding" lookups are unwrap()ed as well.
  * test.sql: the first ".exit" makes everything below it unreachable when
    the script is run top to bottom (presumably intentional scratch).

 src/batch.rs   | 220 +++++++++++++++++++++++++++++++++++++++++++++++++
 src/clients.rs |  50 +++++++++++
 src/lib.rs     |   7 +-
 test.sql       |  66 +++++++++++++++
 4 files changed, 341 insertions(+), 2 deletions(-)
 create mode 100644 src/batch.rs

diff --git a/src/batch.rs b/src/batch.rs
new file mode 100644
index 0000000..47ce70a
--- /dev/null
+++ b/src/batch.rs
@@ -0,0 +1,220 @@
+use crate::{Client, FLOAT32_VECTOR_SUBTYPE};
+use sqlite_loadable::{
+    api,
+    table::{ConstraintOperator, IndexInfo, VTab, VTabArguments, VTabCursor},
+    BestIndexError, Result,
+};
+use sqlite_loadable::{prelude::*, Error};
+use std::{cell::RefCell, collections::HashMap, marker::PhantomData, mem, os::raw::c_int, rc::Rc};
+use zerocopy::AsBytes;
+
+static CREATE_SQL: &str =
+    "CREATE TABLE x(contents, embedding, input1 hidden, input2 hidden, source hidden)";
+enum Columns {
+    Contents,
+    Embedding,
+    Input1,
+    Input2,
+    Source,
+}
+fn column(index: i32) -> Option<Columns> {
+    match index {
+        0 => Some(Columns::Contents),
+        1 => Some(Columns::Embedding),
+        2 => Some(Columns::Input1),
+        3 => Some(Columns::Input2),
+        4 => Some(Columns::Source),
+        _ => None,
+    }
+}
+
+#[repr(C)]
+pub struct BatchTable {
+    /// must be first
+    base: sqlite3_vtab,
+    clients: Rc<RefCell<HashMap<String, Client>>>,
+}
+
+impl<'vtab> VTab<'vtab> for BatchTable {
+    type Aux = Rc<RefCell<HashMap<String, Client>>>;
+    type Cursor = BatchCursor<'vtab>;
+
+    fn connect(
+        _db: *mut sqlite3,
+        aux: Option<&Self::Aux>,
+        _args: VTabArguments,
+    ) -> Result<(String, BatchTable)> {
+        let base: sqlite3_vtab = unsafe { mem::zeroed() };
+        let vtab = BatchTable {
+            base,
+            clients: aux.unwrap().clone(),
+        };
+        // TODO db.config(VTabConfig::Innocuous)?;
+        Ok((CREATE_SQL.to_owned(), vtab))
+    }
+    fn destroy(&self) -> Result<()> {
+        Ok(())
+    }
+
+    fn best_index(&self, mut info: IndexInfo) -> core::result::Result<(), BestIndexError> {
+        let mut has_input1 = false;
+        let mut has_input2 = false;
+        for mut constraint in info.constraints() {
+            match column(constraint.column_idx()) {
+                Some(Columns::Input1) => {
+                    if constraint.usable() && constraint.op() == Some(ConstraintOperator::EQ) {
+                        constraint.set_omit(true);
+                        constraint.set_argv_index(1);
+                        has_input1 = true;
+                    } else {
+                        return Err(BestIndexError::Constraint);
+                    }
+                }
+                Some(Columns::Input2) => {
+                    if constraint.usable() && constraint.op() == Some(ConstraintOperator::EQ) {
+                        constraint.set_omit(true);
+                        constraint.set_argv_index(2);
+                        has_input2 = true;
+                    } else {
+                        return Err(BestIndexError::Constraint);
+                    }
+                }
+                _ => (),
+            }
+        }
+        if !has_input1 {
+            return Err(BestIndexError::Error);
+        }
+        info.set_estimated_cost(100000.0);
+        info.set_estimated_rows(100000);
+        info.set_idxnum(2);
+
+        Ok(())
+    }
+
+    fn open(&mut self) -> Result<BatchCursor<'_>> {
+        Ok(BatchCursor::new(self.clients.clone()))
+    }
+}
+
+type Entry = (serde_json::Value, Vec<f32>);
+#[repr(C)]
+pub struct BatchCursor<'vtab> {
+    /// Base class. Must be first
+    base: sqlite3_vtab_cursor,
+    clients: Rc<RefCell<HashMap<String, Client>>>,
+    results: Option<Vec<Entry>>,
+    curr: usize,
+    phantom: PhantomData<&'vtab BatchTable>,
+}
+impl BatchCursor<'_> {
+    fn new<'vtab>(clients: Rc<RefCell<HashMap<String, Client>>>) -> BatchCursor<'vtab> {
+        let base: sqlite3_vtab_cursor = unsafe { mem::zeroed() };
+        BatchCursor {
+            base,
+            clients: clients,
+            results: None,
+            curr: 0,
+            phantom: PhantomData,
+        }
+    }
+}
+
+impl VTabCursor for BatchCursor<'_> {
+    fn filter(
+        &mut self,
+        _idx_num: c_int,
+        _idx_str: Option<&str>,
+        values: &[*mut sqlite3_value],
+    ) -> Result<()> {
+        self.curr = 0;
+        let first = values.get(0).unwrap();
+        let (client_name, input) = match values.get(1) {
+            Some(v) => (api::value_text(first).unwrap(), api::value_text(v).unwrap()),
+            None => ("default", api::value_text(first).unwrap()),
+        };
+
+        let x = self.clients.borrow();
+        let client = x.get(client_name).ok_or_else(|| {
+            Error::new_message(format!(
+                "Client with name {client_name} was not registered with rembed_clients."
+            ))
+        })?;
+
+        let input: serde_json::Value = serde_json::from_str(input).unwrap();
+        let input = input.as_array().unwrap();
+        let x: Vec<String> = input
+            .iter()
+            .map(|v| {
+                let contents = v.get("contents").unwrap().as_str().unwrap();
+                contents.to_string()
+            })
+            .collect();
+        let embeddings = match client {
+            Client::Ollama(c) => c.infer_multiple(x).unwrap(),
+            _ => todo!(),
+        };
+        self.results = Some(
+            embeddings
+                .iter()
+                .zip(input)
+                .map(|(emb, val)| (val.to_owned(), emb.to_owned()))
+                .collect(),
+        );
+
+        Ok(())
+    }
+
+    fn next(&mut self) -> Result<()> {
+        self.curr += 1;
+        Ok(())
+    }
+
+    fn eof(&self) -> bool {
+        self.results
+            .as_ref()
+            .map_or(true, |v| v.get(self.curr).is_none())
+    }
+
+    fn column(&self, context: *mut sqlite3_context, i: c_int) -> Result<()> {
+        match column(i) {
+            Some(Columns::Contents) => {
+                api::result_text(
+                    context,
+                    self.results
+                        .as_ref()
+                        .unwrap()
+                        .get(self.curr)
+                        .unwrap()
+                        .0
+                        .get("contents")
+                        .unwrap()
+                        .as_str()
+                        .unwrap(),
+                )?;
+            }
+            Some(Columns::Embedding) => {
+                api::result_blob(
+                    context,
+                    self.results
+                        .as_ref()
+                        .unwrap()
+                        .get(self.curr)
+                        .unwrap()
+                        .1
+                        .as_bytes(),
+                );
+                api::result_subtype(context, FLOAT32_VECTOR_SUBTYPE);
+            }
+            Some(Columns::Input1) => todo!(),
+            Some(Columns::Input2) => todo!(),
+            Some(Columns::Source) => todo!(),
+            None => todo!(),
+        }
+        Ok(())
+    }
+
+    fn rowid(&self) -> Result<i64> {
+        Ok(self.curr as i64)
+    }
+}
diff --git a/src/clients.rs b/src/clients.rs
index 5f83b9a..308387b 100644
--- a/src/clients.rs
+++ b/src/clients.rs
@@ -447,6 +447,26 @@ impl OllamaClient {
             })?;
         OllamaClient::parse_single_response(data)
     }
+    pub fn infer_multiple(&self, input: Vec<String>) -> Result<Vec<Vec<f32>>> {
+        let mut body = serde_json::Map::new();
+        body.insert("model".to_owned(), self.model.to_owned().into());
+
+        body.insert("input".to_owned(), input.into());
+        let body = serde_json::to_vec(&body).map_err(|error| {
+            Error::new_message(format!("Error serializing body to JSON: {error}"))
+        })?;
+
+        let data: serde_json::Value = ureq::post("http://localhost:11434/v1/embeddings")
+            .set("Content-Type", "application/json")
+            .send_bytes(body.as_ref())
+            .map_err(|error| Error::new_message(format!("Error sending HTTP request: {error}")))?
+            .into_json()
+            .map_err(|error| {
+                Error::new_message(format!("Error parsing HTTP response as JSON: {error}"))
+            })?;
+        OllamaClient::parse_multiple_response(data)
+    }
+
     pub fn parse_single_response(value: serde_json::Value) -> Result<Vec<f32>> {
         value
             .get("embedding")
@@ -467,6 +487,36 @@ impl OllamaClient {
                     .collect()
             })
     }
+    pub fn parse_multiple_response(value: serde_json::Value) -> Result<Vec<Vec<f32>>> {
+        let data = value
+            .get("data")
+            .ok_or_else(|| Error::new_message("expected 'data' key in response body"))
+            .and_then(|v| {
+                v.as_array()
+                    .ok_or_else(|| Error::new_message("expected 'data' path to be an array"))
+            })
+            .unwrap();
+
+        data.iter()
+            .map(|v| {
+                let embedding_object = v.as_object().unwrap();
+                embedding_object
+                    .get("embedding")
+                    .unwrap()
+                    .as_array()
+                    .unwrap()
+                    .iter()
+                    .map(|v| {
+                        v.as_f64()
+                            .ok_or_else(|| {
+                                Error::new_message("expected 'embedding' array to contain floats")
+                            })
+                            .map(|f| f as f32)
+                    })
+                    .collect()
+            })
+            .collect()
+    }
 }
 
 #[derive(Clone)]
diff --git a/src/lib.rs b/src/lib.rs
index 1924525..3ab5ab6 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,3 +1,4 @@
+mod batch;
 mod clients;
 mod clients_vtab;
 
@@ -5,11 +6,12 @@ use std::cell::RefCell;
 use std::collections::HashMap;
 use std::rc::Rc;
 
+use batch::BatchTable;
 use clients::{Client, CohereClient, LlamafileClient, NomicClient, OllamaClient, OpenAiClient};
 use clients_vtab::ClientsTable;
 use sqlite_loadable::{
-    api, define_scalar_function, define_scalar_function_with_aux, define_virtual_table_writeablex,
-    prelude::*, Error, Result,
+    api, define_scalar_function, define_scalar_function_with_aux, define_table_function,
+    define_virtual_table_writeablex, prelude::*, Error, Result,
 };
 use zerocopy::AsBytes;
 
@@ -165,5 +167,6 @@ pub fn sqlite3_rembed_init(db: *mut sqlite3) -> Result<()> {
         flags,
     )?;
     define_virtual_table_writeablex::<ClientsTable>(db, "rembed_clients", Some(Rc::clone(&c)))?;
+    define_table_function::<BatchTable>(db, "rembed_batch", Some(Rc::clone(&c)))?;
     Ok(())
 }
diff --git a/test.sql b/test.sql
index d1e8e85..552ecb9 100644
--- a/test.sql
+++ b/test.sql
@@ -5,6 +5,72 @@
 .timer on
 .echo on
 
+INSERT INTO temp.rembed_clients(name, options) VALUES
+  ('snowflake-arctic-embed:s', 'ollama');
+
+create table articles as
+  select column1 as headline
+  from (VALUES
+    ('Shohei Ohtani''s ex-interpreter pleads guilty to charges related to gambling and theft'),
+    ('The jury has been selected in Hunter Biden''s gun trial'),
+    ('Larry Allen, a Super Bowl champion and famed Dallas Cowboy, has died at age 52'),
+    ('After saying Charlotte, a lone stingray, was pregnant, aquarium now says she''s sick'),
+    ('An Epoch Times executive is facing money laundering charge'),
+    ('Hassan Nasrallah’s killing transforms an already deadly regional conflict'),
+    ('Who was Hassan Nasrallah, the Hezbollah leader killed by Israel?'),
+    ('What is Hezbollah, the militia fighting Israel in Lebanon?'),
+    ('Netanyahu defies calls for a cease-fire at the U.N., as Israel strikes Lebanon'),
+    ('Death toll from Hurricane Helene mounts as aftermath assessment begins'),
+    ('5 things to know from this week’s big report on cannabis'),
+    ('VP debates may alter a close race’s dynamic even when they don''t predict the winner'),
+    ('SpaceX launches ISS-bound crew that hopes to bring home 2 stuck astronauts'),
+    ('Why the price of eggs is on the rise again'),
+    ('A guide to your weekend viewing and reading'),
+    ('At the border in Arizona, Harris lays out a plan to get tough on fentanyl'),
+    ('A new kind of drug for schizophrenia promises fewer side effects'),
+    ('Meet the astronauts preparing to travel farther from Earth than any human before'),
+    ('‘SNL’ has always taken on politics. Here’s what works — and why'),
+    ('Golden-age rappers make a digital-age leap — and survive'),
+    ('Why Russia''s broadcaster RT turned to covertly funding American pro-Trump influencers'),
+    ('Read the indictment: NYC Mayor Eric Adams charged with bribery, fraud, foreign donations'),
+    ('Justice Department sues Alabama, claiming it purged voters too close to the election'),
+    ('Exactly 66 years ago, another Hurricane Helene rocked the Carolinas'),
+    ('A meteorologist in Atlanta rescued a woman from Helene floodwaters on camera')
+  );
+
+select * from articles;
+
+.timer on
+select headline, length(rembed('snowflake-arctic-embed:s', headline)) from articles;
+
+select contents, length(embedding)
+from rembed_batch(
+  'snowflake-arctic-embed:s',
+  (
+    select json_group_array(
+      json_object(
+        'id', rowid,
+        'contents', headline
+      )
+    ) from articles
+  )
+);
+.exit
+
+
+select *
+from rembed_batch(
+  'snowflake-arctic-embed:s',
+  json('[
+    {"id": 1, "contents": "alex garcia"},
+    {"id": 1, "contents": "joe biden"},
+    {"id": 1, "contents": "kamala harris"}
+]'));
+
+
+.exit
+
+
 INSERT INTO temp.rembed_clients(name, options) VALUES
   ('text-embedding-3-small','openai'),
   ('jina-embeddings-v2-base-en','jina'),