From f5d2d18cdd4b5242a98f162a6d0ab8cdb66db5fb Mon Sep 17 00:00:00 2001
From: Lakshan Perera
Date: Wed, 21 Feb 2024 10:41:58 +1100
Subject: [PATCH] fix: refactor Supabase AI API (#273)

* fix: refactor sb_ai

* fix: make the API async

* fix: clippy
---
 Cargo.lock                      |   1 +
 crates/sb_ai/Cargo.toml         |   1 +
 crates/sb_ai/ai.js              |  15 ++-
 crates/sb_ai/lib.rs             | 175 +++++++++++++++++++-------------
 crates/sb_core/js/bootstrap.js  |   3 +-
 examples/gte-small-ort/index.ts |   4 +-
 6 files changed, 124 insertions(+), 75 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 106e29be..2202bda0 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -4548,6 +4548,7 @@ dependencies = [
  "rand",
  "serde",
  "tokenizers",
+ "tokio",
 ]
 
 [[package]]
diff --git a/crates/sb_ai/Cargo.toml b/crates/sb_ai/Cargo.toml
index cf2a1da8..df54a455 100644
--- a/crates/sb_ai/Cargo.toml
+++ b/crates/sb_ai/Cargo.toml
@@ -18,4 +18,5 @@ ort = { version = "2.0.0-alpha.4", default-features = false, features = [ "ndarr
 ndarray = "0.15"
 ndarray-linalg = "0.15"
 tokenizers = { version = ">=0.13.4", default-features = false, features = [ "onig" ] }
+tokio.workspace = true
 rand = "0.8"
diff --git a/crates/sb_ai/ai.js b/crates/sb_ai/ai.js
index 3229c464..7336a031 100644
--- a/crates/sb_ai/ai.js
+++ b/crates/sb_ai/ai.js
@@ -1,10 +1,17 @@
 const core = globalThis.Deno.core;
 
-class SupabaseAI {
-	runModel(name, prompt) {
-		const result = core.ops.op_sb_ai_run_model(name, prompt);
+class Session {
+	model;
+
+	constructor(model) {
+		this.model = model;
+		core.ops.op_sb_ai_init_model(model);
+	}
+
+	async run(prompt) {
+		const result = await core.ops.op_sb_ai_run_model(this.model, prompt);
 		return result;
 	}
 }
 
-export { SupabaseAI };
+export default { Session };
diff --git a/crates/sb_ai/lib.rs b/crates/sb_ai/lib.rs
index 333ec996..d1aa4170 100644
--- a/crates/sb_ai/lib.rs
+++ b/crates/sb_ai/lib.rs
@@ -5,106 +5,145 @@ use deno_core::OpState;
 use ndarray::{Array1, Array2, Axis, Ix2};
 use ndarray_linalg::norm::{normalize, NormalizeAxis};
 use ort::{inputs, GraphOptimizationLevel, Session, Tensor};
+use std::cell::RefCell;
 use std::path::Path;
+use std::rc::Rc;
 use tokenizers::normalizers::bert::BertNormalizer;
 use tokenizers::Tokenizer;
+use tokio::sync::mpsc;
+use tokio::task;
 
 deno_core::extension!(
     sb_ai,
-    ops = [op_sb_ai_run_model],
+    ops = [op_sb_ai_run_model, op_sb_ai_init_model],
     esm_entry_point = "ext:sb_ai/ai.js",
     esm = ["ai.js",]
 );
 
-fn run_gte(state: &mut OpState, prompt: String) -> Result<Vec<f32>, Error> {
+struct GteModelRequest {
+    prompt: String,
+    result_tx: mpsc::UnboundedSender<Vec<f32>>,
+}
+
+fn init_gte(state: &mut OpState) -> Result<(), Error> {
     // Create the ONNX Runtime environment, for all sessions created in this process.
     ort::init().with_name("GTE").commit()?;
 
     let models_dir = std::env::var("SB_AI_MODELS_DIR").unwrap_or("/etc/sb_ai/models".to_string());
 
-    let mut session = state.try_take::<Session>();
-    if session.is_none() {
-        session = Some(
-            Session::builder()?
-                .with_optimization_level(GraphOptimizationLevel::Disable)?
-                .with_intra_threads(1)?
-                .with_model_from_file(
-                    Path::new(&models_dir)
-                        .join("gte")
-                        .join("gte_small_quantized.onnx"),
-                )?,
-        );
-    }
-    let session = session.unwrap();
+    let (req_tx, mut req_rx) = mpsc::unbounded_channel::<GteModelRequest>();
+    state.put::<mpsc::UnboundedSender<GteModelRequest>>(req_tx);
 
-    // Load the tokenizer and encode the prompt into a sequence of tokens.
-    let mut tokenizer = state.try_take::<Tokenizer>();
-    if tokenizer.is_none() {
-        tokenizer = Some(
-            Tokenizer::from_file(
+    #[allow(clippy::let_underscore_future)]
+    let _: task::JoinHandle<Result<(), Error>> = task::spawn(async move {
+        let session = Session::builder()?
+            .with_optimization_level(GraphOptimizationLevel::Disable)?
+            .with_intra_threads(1)?
+            .with_model_from_file(
                 Path::new(&models_dir)
                     .join("gte")
-                    .join("gte_small_tokenizer.json"),
-            )
-            .map_err(anyhow::Error::msg)?,
-        )
-    }
-    let mut tokenizer = tokenizer.unwrap();
+                    .join("gte_small_quantized.onnx"),
+            )?;
 
-    let tokenizer_impl = tokenizer
-        .with_normalizer(BertNormalizer::default())
-        .with_padding(None)
-        .with_truncation(None)
+        let mut tokenizer = Tokenizer::from_file(
+            Path::new(&models_dir)
+                .join("gte")
+                .join("gte_small_tokenizer.json"),
+        )
         .map_err(anyhow::Error::msg)?;
 
-    let tokens = tokenizer_impl
-        .encode(prompt, true)
-        .map_err(anyhow::Error::msg)?
-        .get_ids()
-        .iter()
-        .map(|i| *i as i64)
-        .collect::<Vec<i64>>();
-
-    let tokens = Array1::from_iter(tokens.iter().cloned());
-
-    let array = tokens.view().insert_axis(Axis(0));
-    let dims = array.raw_dim();
-    let token_type_ids = Array2::<i64>::zeros(dims);
-    let attention_mask = Array2::<i64>::ones(dims);
-    let outputs = session.run(inputs! {
-        "input_ids" => array,
-        "token_type_ids" => token_type_ids,
-        "attention_mask" => attention_mask,
-    }?)?;
-
-    let embeddings: Tensor<f32> = outputs["last_hidden_state"].extract_tensor()?;
-
-    let embeddings_view = embeddings.view();
-    let mean_pool = embeddings_view.mean_axis(Axis(1)).unwrap();
-    let (normalized, _) = normalize(
-        mean_pool.into_dimensionality::<Ix2>().unwrap(),
-        NormalizeAxis::Row,
-    );
+        let tokenizer_impl = tokenizer
+            .with_normalizer(BertNormalizer::default())
+            .with_padding(None)
+            .with_truncation(None)
+            .map_err(anyhow::Error::msg)?;
+
+        loop {
+            let req = req_rx.recv().await;
+            if req.is_none() {
+                break;
+            }
+            let req = req.unwrap();
+
+            let tokens = tokenizer_impl
+                .encode(req.prompt, true)
+                .map_err(anyhow::Error::msg)?
+                .get_ids()
+                .iter()
+                .map(|i| *i as i64)
+                .collect::<Vec<i64>>();
+
+            let tokens = Array1::from_iter(tokens.iter().cloned());
+            let array = tokens.view().insert_axis(Axis(0));
+
+            let dims = array.raw_dim();
+            let token_type_ids = Array2::<i64>::zeros(dims);
+            let attention_mask = Array2::<i64>::ones(dims);
+            let outputs = session.run(inputs! {
{ + "input_ids" => array, + "token_type_ids" => token_type_ids, + "attention_mask" => attention_mask, + }?)?; + + let embeddings: Tensor = outputs["last_hidden_state"].extract_tensor()?; + + let embeddings_view = embeddings.view(); + let mean_pool = embeddings_view.mean_axis(Axis(1)).unwrap(); + let (normalized, _) = normalize( + mean_pool.into_dimensionality::().unwrap(), + NormalizeAxis::Row, + ); + + let result = normalized.view().to_slice().unwrap().to_vec(); + req.result_tx.send(result)?; + } + Ok(()) + }); + + Ok(()) +} - let slice = normalized.view().to_slice().unwrap().to_vec(); +async fn run_gte(state: Rc>, prompt: String) -> Result, Error> { + let req_tx; + { + let op_state = state.borrow(); + let maybe_req_tx = op_state.try_borrow::>(); + if maybe_req_tx.is_none() { + bail!("Run init model first") + } + req_tx = maybe_req_tx.unwrap().clone(); + } - drop(outputs); + let (result_tx, mut result_rx) = mpsc::unbounded_channel::>(); - state.put::(session); - state.put::(tokenizer); + req_tx.send(GteModelRequest { + prompt, + result_tx: result_tx.clone(), + })?; - Ok(slice) + let result = result_rx.recv().await; + Ok(result.unwrap()) } #[op2] #[serde] -pub fn op_sb_ai_run_model( - state: &mut OpState, +pub fn op_sb_ai_init_model(state: &mut OpState, #[string] name: String) -> Result<(), AnyError> { + if name == "gte-small" { + init_gte(state) + } else { + bail!("model not supported") + } +} + +#[op2(async)] +#[serde] +pub async fn op_sb_ai_run_model( + state: Rc>, #[string] name: String, #[string] prompt: String, ) -> Result, AnyError> { - if name == "gte" { - run_gte(state, prompt) + if name == "gte-small" { + run_gte(state, prompt).await } else { bail!("model not supported") } diff --git a/crates/sb_core/js/bootstrap.js b/crates/sb_core/js/bootstrap.js index bfff7fda..789757ac 100644 --- a/crates/sb_core/js/bootstrap.js +++ b/crates/sb_core/js/bootstrap.js @@ -20,7 +20,7 @@ import * as response from 'ext:deno_fetch/23_response.js'; import * as request from 'ext:deno_fetch/23_request.js'; import * as globalInterfaces from 'ext:deno_web/04_global_interfaces.js'; import { SUPABASE_ENV } from 'ext:sb_env/env.js'; -import { SupabaseAI } from 'ext:sb_ai/ai.js'; +import ai from 'ext:sb_ai/ai.js'; import { registerErrors } from 'ext:sb_core_main_js/js/errors.js'; import { formatException, @@ -326,7 +326,6 @@ globalThis.bootstrapSBEdge = ( ); setLanguage('en'); - const ai = new SupabaseAI(); Object.defineProperty(globalThis, 'Supabase_UNSTABLE', { get() { return { diff --git a/examples/gte-small-ort/index.ts b/examples/gte-small-ort/index.ts index b41988dd..7e1288c7 100644 --- a/examples/gte-small-ort/index.ts +++ b/examples/gte-small-ort/index.ts @@ -1,7 +1,9 @@ +const model = new Supabase_UNSTABLE.ai.Session('gte-small'); + Deno.serve(async (req: Request) => { const params = new URL(req.url).searchParams; const input = params.get('text'); - const output = Supabase_UNSTABLE.ai.runModel('gte', input); + const output = await model.run(input); return new Response( JSON.stringify( output,