Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DRAFT: batching support #14

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
220 changes: 220 additions & 0 deletions src/batch.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
use crate::{Client, FLOAT32_VECTOR_SUBTYPE};
use sqlite_loadable::{
api,
table::{ConstraintOperator, IndexInfo, VTab, VTabArguments, VTabCursor},
BestIndexError, Result,
};
use sqlite_loadable::{prelude::*, Error};
use std::{cell::RefCell, collections::HashMap, marker::PhantomData, mem, os::raw::c_int, rc::Rc};
use zerocopy::AsBytes;

/// Schema declared for the batch virtual table: two visible output columns
/// (`contents`, `embedding`) followed by three hidden columns
/// (`input1`, `input2`, `source`).
///
/// The column order here must stay in sync with the indices handled by `column()`.
static CREATE_SQL: &str =
"CREATE TABLE x(contents, embedding, input1 hidden, input2 hidden, source hidden)";
/// Columns of the batch virtual table, in the order they are declared in `CREATE_SQL`.
///
/// Derives the cheap standard traits so values can be logged, compared in
/// assertions and passed by value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Columns {
    /// Visible output column: the `contents` string of the current input element.
    Contents,
    /// Visible output column: the embedding of the current input element.
    Embedding,
    /// Hidden column bound to the first table-function argument.
    Input1,
    /// Hidden column bound to the second (optional) table-function argument.
    Input2,
    /// Hidden column declared in the schema; not produced by the visible code.
    Source,
}
/// Translates a zero-based SQLite column index into its [`Columns`] variant.
///
/// Returns `None` for any index outside `0..=4`.
fn column(index: i32) -> Option<Columns> {
    let resolved = match index {
        0 => Columns::Contents,
        1 => Columns::Embedding,
        2 => Columns::Input1,
        3 => Columns::Input2,
        4 => Columns::Source,
        // Anything else is not part of the declared schema.
        _ => return None,
    };
    Some(resolved)
}

/// Virtual table state for the batch embedding table function.
///
/// `#[repr(C)]` keeps the declared field order so that `base` really is the
/// first field of the struct.
#[repr(C)]
pub struct BatchTable {
/// must be first
base: sqlite3_vtab,
/// Shared map of clients keyed by name; cloned from the aux data in `connect`.
clients: Rc<RefCell<HashMap<String, Client>>>,
}

impl<'vtab> VTab<'vtab> for BatchTable {
type Aux = Rc<RefCell<HashMap<String, Client>>>;
type Cursor = BatchCursor<'vtab>;

fn connect(
_db: *mut sqlite3,
aux: Option<&Self::Aux>,
_args: VTabArguments,
) -> Result<(String, BatchTable)> {
let base: sqlite3_vtab = unsafe { mem::zeroed() };
let vtab = BatchTable {
base,
clients: aux.unwrap().clone(),
};
// TODO db.config(VTabConfig::Innocuous)?;
Ok((CREATE_SQL.to_owned(), vtab))
}
fn destroy(&self) -> Result<()> {
Ok(())
}

fn best_index(&self, mut info: IndexInfo) -> core::result::Result<(), BestIndexError> {
let mut has_input1 = false;
let mut has_input2 = false;
for mut constraint in info.constraints() {
match column(constraint.column_idx()) {
Some(Columns::Input1) => {
if constraint.usable() && constraint.op() == Some(ConstraintOperator::EQ) {
constraint.set_omit(true);
constraint.set_argv_index(1);
has_input1 = true;
} else {
return Err(BestIndexError::Constraint);
}
}
Some(Columns::Input2) => {
if constraint.usable() && constraint.op() == Some(ConstraintOperator::EQ) {
constraint.set_omit(true);
constraint.set_argv_index(2);
has_input2 = true;
} else {
return Err(BestIndexError::Constraint);
}
}
_ => (),
}
}
if !has_input1 {
return Err(BestIndexError::Error);
}
info.set_estimated_cost(100000.0);
info.set_estimated_rows(100000);
info.set_idxnum(2);

Ok(())
}

fn open(&mut self) -> Result<BatchCursor<'_>> {
Ok(BatchCursor::new(self.clients.clone()))
}
}

/// One result row: a JSON value paired with a vector of `f32` values
/// (filled in `filter` from an input element and its embedding).
type Entry = (serde_json::Value, Vec<f32>);
/// Cursor over the rows produced by one scan of the batch table.
#[repr(C)]
pub struct BatchCursor<'vtab> {
/// Base class. Must be first
base: sqlite3_vtab_cursor,
/// Shared map of clients keyed by name, used by `filter` to pick the client.
clients: Rc<RefCell<HashMap<String, Client>>>,
/// Rows computed by `filter`; `None` until `filter` has stored a result.
results: Option<Vec<Entry>>,
/// Zero-based index of the current row within `results`.
curr: usize,
/// Carries the `'vtab` lifetime without storing a reference to the table.
phantom: PhantomData<&'vtab BatchTable>,
}
impl BatchCursor<'_> {
    /// Builds a cursor positioned before any results, sharing `clients` with
    /// the table that opened it.
    fn new<'vtab>(clients: Rc<RefCell<HashMap<String, Client>>>) -> BatchCursor<'vtab> {
        BatchCursor {
            // SAFETY: `sqlite3_vtab_cursor` is a plain C struct holding a raw
            // pointer, for which the all-zero bit pattern is a valid value.
            base: unsafe { mem::zeroed() },
            clients,
            results: None,
            curr: 0,
            phantom: PhantomData,
        }
    }
}

impl VTabCursor for BatchCursor<'_> {
fn filter(
&mut self,
_idx_num: c_int,
_idx_str: Option<&str>,
values: &[*mut sqlite3_value],
) -> Result<()> {
self.curr = 0;
let first = values.get(0).unwrap();
let (client_name, input) = match values.get(1) {
Some(v) => (api::value_text(first).unwrap(), api::value_text(v).unwrap()),
None => ("default", api::value_text(first).unwrap()),
};

let x = self.clients.borrow();
let client = x.get(client_name).ok_or_else(|| {
Error::new_message(format!(
"Client with name {client_name} was not registered with rembed_clients."
))
})?;

let input: serde_json::Value = serde_json::from_str(input).unwrap();
let input = input.as_array().unwrap();
let x: Vec<String> = input
.iter()
.map(|v| {
let contents = v.get("contents").unwrap().as_str().unwrap();
contents.to_string()
})
.collect();
let embeddings = match client {
Client::Ollama(c) => c.infer_multiple(x).unwrap(),
_ => todo!(),
};
self.results = Some(
embeddings
.iter()
.zip(input)
.map(|(emb, val)| (val.to_owned(), emb.to_owned()))
.collect(),
);

Ok(())
}

fn next(&mut self) -> Result<()> {
self.curr += 1;
Ok(())
}

fn eof(&self) -> bool {
self.results
.as_ref()
.map_or(true, |v| v.get(self.curr).is_none())
}

fn column(&self, context: *mut sqlite3_context, i: c_int) -> Result<()> {
match column(i) {
Some(Columns::Contents) => {
api::result_text(
context,
self.results
.as_ref()
.unwrap()
.get(self.curr)
.unwrap()
.0
.get("contents")
.unwrap()
.as_str()
.unwrap(),
)?;
}
Some(Columns::Embedding) => {
api::result_blob(
context,
self.results
.as_ref()
.unwrap()
.get(self.curr)
.unwrap()
.1
.as_bytes(),
);
api::result_subtype(context, FLOAT32_VECTOR_SUBTYPE);
}
Some(Columns::Input1) => todo!(),
Some(Columns::Input2) => todo!(),
Some(Columns::Source) => todo!(),
None => todo!(),
}
Ok(())
}

fn rowid(&self) -> Result<i64> {
Ok(self.curr as i64)
}
}
50 changes: 50 additions & 0 deletions src/clients.rs
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,26 @@ impl OllamaClient {
})?;
OllamaClient::parse_single_response(data)
}
pub fn infer_multiple(&self, input: Vec<String>) -> Result<Vec<Vec<f32>>> {
let mut body = serde_json::Map::new();
body.insert("model".to_owned(), self.model.to_owned().into());

body.insert("input".to_owned(), input.into());
let body = serde_json::to_vec(&body).map_err(|error| {
Error::new_message(format!("Error serializing body to JSON: {error}"))
})?;

let data: serde_json::Value = ureq::post("http://localhost:11434/v1/embeddings")
.set("Content-Type", "application/json")
.send_bytes(body.as_ref())
.map_err(|error| Error::new_message(format!("Error sending HTTP request: {error}")))?
.into_json()
.map_err(|error| {
Error::new_message(format!("Error parsing HTTP response as JSON: {error}"))
})?;
OllamaClient::parse_multiple_response(data)
}

pub fn parse_single_response(value: serde_json::Value) -> Result<Vec<f32>> {
value
.get("embedding")
Expand All @@ -467,6 +487,36 @@ impl OllamaClient {
.collect()
})
}
pub fn parse_multiple_response(value: serde_json::Value) -> Result<Vec<Vec<f32>>> {
let data = value
.get("data")
.ok_or_else(|| Error::new_message("expected 'data' key in response body"))
.and_then(|v| {
v.as_array()
.ok_or_else(|| Error::new_message("expected 'data' path to be an array"))
})
.unwrap();

data.iter()
.map(|v| {
let embedding_object = v.as_object().unwrap();
embedding_object
.get("embedding")
.unwrap()
.as_array()
.unwrap()
.iter()
.map(|v| {
v.as_f64()
.ok_or_else(|| {
Error::new_message("expected 'embedding' array to contain floats")
})
.map(|f| f as f32)
})
.collect()
})
.collect()
}
}

#[derive(Clone)]
Expand Down
7 changes: 5 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
mod batch;
mod clients;
mod clients_vtab;

use std::cell::RefCell;
use std::collections::HashMap;
use std::rc::Rc;

use batch::BatchTable;
use clients::{Client, CohereClient, LlamafileClient, NomicClient, OllamaClient, OpenAiClient};
use clients_vtab::ClientsTable;
use sqlite_loadable::{
api, define_scalar_function, define_scalar_function_with_aux, define_virtual_table_writeablex,
prelude::*, Error, Result,
api, define_scalar_function, define_scalar_function_with_aux, define_table_function,
define_virtual_table_writeablex, prelude::*, Error, Result,
};
use zerocopy::AsBytes;

Expand Down Expand Up @@ -165,5 +167,6 @@ pub fn sqlite3_rembed_init(db: *mut sqlite3) -> Result<()> {
flags,
)?;
define_virtual_table_writeablex::<ClientsTable>(db, "rembed_clients", Some(Rc::clone(&c)))?;
define_table_function::<BatchTable>(db, "rembed_batch", Some(Rc::clone(&c)))?;
Ok(())
}
66 changes: 66 additions & 0 deletions test.sql
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,72 @@
.timer on
.echo on

-- Register a client named 'snowflake-arctic-embed:s' with the 'ollama' option,
-- so the queries below can refer to it by that name.
INSERT INTO temp.rembed_clients(name, options) VALUES
('snowflake-arctic-embed:s', 'ollama');

-- Fixture: a single-column table of sample headlines used as embedding input.
create table articles as
select column1 as headline
from (VALUES
('Shohei Ohtani''s ex-interpreter pleads guilty to charges related to gambling and theft'),
('The jury has been selected in Hunter Biden''s gun trial'),
('Larry Allen, a Super Bowl champion and famed Dallas Cowboy, has died at age 52'),
('After saying Charlotte, a lone stingray, was pregnant, aquarium now says she''s sick'),
('An Epoch Times executive is facing money laundering charge'),
('Hassan Nasrallah’s killing transforms an already deadly regional conflict'),
('Who was Hassan Nasrallah, the Hezbollah leader killed by Israel?'),
('What is Hezbollah, the militia fighting Israel in Lebanon?'),
('Netanyahu defies calls for a cease-fire at the U.N., as Israel strikes Lebanon'),
('Death toll from Hurricane Helene mounts as aftermath assessment begins'),
('5 things to know from this week’s big report on cannabis'),
('VP debates may alter a close race’s dynamic even when they don''t predict the winner'),
('SpaceX launches ISS-bound crew that hopes to bring home 2 stuck astronauts'),
('Why the price of eggs is on the rise again'),
('A guide to your weekend viewing and reading'),
('At the border in Arizona, Harris lays out a plan to get tough on fentanyl'),
('A new kind of drug for schizophrenia promises fewer side effects'),
('Meet the astronauts preparing to travel farther from Earth than any human before'),
('‘SNL’ has always taken on politics. Here’s what works — and why'),
('Golden-age rappers make a digital-age leap — and survive'),
('Why Russia''s broadcaster RT turned to covertly funding American pro-Trump influencers'),
('Read the indictment: NYC Mayor Eric Adams charged with bribery, fraud, foreign donations'),
('Justice Department sues Alabama, claiming it purged voters too close to the election'),
('Exactly 66 years ago, another Hurricane Helene rocked the Carolinas'),
('A meteorologist in Atlanta rescued a woman from Helene floodwaters on camera')
);

-- Echo the fixture rows.
select * from articles;

-- Baseline: embed each headline with one rembed() call per row.
.timer on
select headline, length(rembed('snowflake-arctic-embed:s', headline)) from articles;

-- Batched: pack every headline into one JSON array and pass it to a single
-- rembed_batch() call, with the client name as the first argument. Each array
-- element is an object whose 'contents' value is the text to embed.
select contents, length(embedding)
from rembed_batch(
'snowflake-arctic-embed:s',
(
select json_group_array(
json_object(
'id', rowid,
'contents', headline
)
) from articles
)
);
-- Leave the shell here; the statements after this line are not run.
.exit


-- NOTE(review): not reached while the `.exit` earlier in this script is in place.
-- Smaller rembed_batch() example that passes a literal JSON array.
select *
from rembed_batch(
'snowflake-arctic-embed:s',
json('[
{"id": 1, "contents": "alex garcia"},
{"id": 1, "contents": "joe biden"},
{"id": 1, "contents": "kamala harris"}
]'));


.exit


INSERT INTO temp.rembed_clients(name, options) VALUES
('text-embedding-3-small','openai'),
('jina-embeddings-v2-base-en','jina'),
Expand Down