
fix: refactor Supabase AI API (supabase#273)
* fix: refactor sb_ai

* fix: make the API async

* fix: clippy
laktek authored Feb 20, 2024
1 parent 1b92a67 commit f5d2d18
Showing 6 changed files with 124 additions and 75 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

(Generated file; diff not rendered.)

1 change: 1 addition & 0 deletions crates/sb_ai/Cargo.toml
@@ -18,4 +18,5 @@
 ndarray = "0.15"
 ndarray-linalg = "0.15"
 tokenizers = { version = ">=0.13.4", default-features = false, features = [ "onig" ] }
+tokio.workspace = true
 rand = "0.8"
15 changes: 11 additions & 4 deletions crates/sb_ai/ai.js
@@ -1,10 +1,17 @@
 const core = globalThis.Deno.core;
 
-class SupabaseAI {
-	runModel(name, prompt) {
-		const result = core.ops.op_sb_ai_run_model(name, prompt);
+class Session {
+	model;
+
+	constructor(model) {
+		this.model = model;
+		core.ops.op_sb_ai_init_model(model);
+	}
+
+	async run(prompt) {
+		const result = await core.ops.op_sb_ai_run_model(this.model, prompt);
 		return result;
 	}
 }
 
-export { SupabaseAI };
+export default { Session };
175 changes: 107 additions & 68 deletions crates/sb_ai/lib.rs
@@ -5,106 +5,145 @@ use deno_core::OpState;
 use ndarray::{Array1, Array2, Axis, Ix2};
 use ndarray_linalg::norm::{normalize, NormalizeAxis};
 use ort::{inputs, GraphOptimizationLevel, Session, Tensor};
+use std::cell::RefCell;
 use std::path::Path;
+use std::rc::Rc;
 use tokenizers::normalizers::bert::BertNormalizer;
 use tokenizers::Tokenizer;
+use tokio::sync::mpsc;
+use tokio::task;
 
 deno_core::extension!(
     sb_ai,
-    ops = [op_sb_ai_run_model],
+    ops = [op_sb_ai_run_model, op_sb_ai_init_model],
     esm_entry_point = "ext:sb_ai/ai.js",
     esm = ["ai.js",]
 );
 
-fn run_gte(state: &mut OpState, prompt: String) -> Result<Vec<f32>, Error> {
+struct GteModelRequest {
+    prompt: String,
+    result_tx: mpsc::UnboundedSender<Vec<f32>>,
+}
+
+fn init_gte(state: &mut OpState) -> Result<(), Error> {
     // Create the ONNX Runtime environment, for all sessions created in this process.
     ort::init().with_name("GTE").commit()?;
 
     let models_dir = std::env::var("SB_AI_MODELS_DIR").unwrap_or("/etc/sb_ai/models".to_string());
 
-    let mut session = state.try_take::<Session>();
-    if session.is_none() {
-        session = Some(
-            Session::builder()?
-                .with_optimization_level(GraphOptimizationLevel::Disable)?
-                .with_intra_threads(1)?
-                .with_model_from_file(
-                    Path::new(&models_dir)
-                        .join("gte")
-                        .join("gte_small_quantized.onnx"),
-                )?,
-        );
-    }
-    let session = session.unwrap();
+    let (req_tx, mut req_rx) = mpsc::unbounded_channel::<GteModelRequest>();
+    state.put::<mpsc::UnboundedSender<GteModelRequest>>(req_tx);
 
-    // Load the tokenizer and encode the prompt into a sequence of tokens.
-    let mut tokenizer = state.try_take::<Tokenizer>();
-    if tokenizer.is_none() {
-        tokenizer = Some(
-            Tokenizer::from_file(
-                Path::new(&models_dir)
-                    .join("gte")
-                    .join("gte_small_tokenizer.json"),
-            )
-            .map_err(anyhow::Error::msg)?,
-        )
-    }
-    let mut tokenizer = tokenizer.unwrap();
+    #[allow(clippy::let_underscore_future)]
+    let _: task::JoinHandle<Result<(), Error>> = task::spawn(async move {
+        let session = Session::builder()?
+            .with_optimization_level(GraphOptimizationLevel::Disable)?
+            .with_intra_threads(1)?
+            .with_model_from_file(
+                Path::new(&models_dir)
+                    .join("gte")
+                    .join("gte_small_quantized.onnx"),
+            )?;
 
-    let tokenizer_impl = tokenizer
-        .with_normalizer(BertNormalizer::default())
-        .with_padding(None)
-        .with_truncation(None)
-        .map_err(anyhow::Error::msg)?;
+        let mut tokenizer = Tokenizer::from_file(
+            Path::new(&models_dir)
+                .join("gte")
+                .join("gte_small_tokenizer.json"),
+        )
+        .map_err(anyhow::Error::msg)?;
 
-    let tokens = tokenizer_impl
-        .encode(prompt, true)
-        .map_err(anyhow::Error::msg)?
-        .get_ids()
-        .iter()
-        .map(|i| *i as i64)
-        .collect::<Vec<_>>();
+        let tokenizer_impl = tokenizer
+            .with_normalizer(BertNormalizer::default())
+            .with_padding(None)
+            .with_truncation(None)
+            .map_err(anyhow::Error::msg)?;
 
-    let tokens = Array1::from_iter(tokens.iter().cloned());
-
-    let array = tokens.view().insert_axis(Axis(0));
-    let dims = array.raw_dim();
-    let token_type_ids = Array2::<i64>::zeros(dims);
-    let attention_mask = Array2::<i64>::ones(dims);
-    let outputs = session.run(inputs! {
-        "input_ids" => array,
-        "token_type_ids" => token_type_ids,
-        "attention_mask" => attention_mask,
-    }?)?;
+        loop {
+            let req = req_rx.recv().await;
+            if req.is_none() {
+                break;
+            }
+            let req = req.unwrap();
 
-    let embeddings: Tensor<f32> = outputs["last_hidden_state"].extract_tensor()?;
+            let tokens = tokenizer_impl
+                .encode(req.prompt, true)
+                .map_err(anyhow::Error::msg)?
+                .get_ids()
+                .iter()
+                .map(|i| *i as i64)
+                .collect::<Vec<_>>();
 
-    let embeddings_view = embeddings.view();
-    let mean_pool = embeddings_view.mean_axis(Axis(1)).unwrap();
-    let (normalized, _) = normalize(
-        mean_pool.into_dimensionality::<Ix2>().unwrap(),
-        NormalizeAxis::Row,
-    );
+            let tokens = Array1::from_iter(tokens.iter().cloned());
+            let array = tokens.view().insert_axis(Axis(0));
 
-    let slice = normalized.view().to_slice().unwrap().to_vec();
+            let dims = array.raw_dim();
+            let token_type_ids = Array2::<i64>::zeros(dims);
+            let attention_mask = Array2::<i64>::ones(dims);
+            let outputs = session.run(inputs! {
+                "input_ids" => array,
+                "token_type_ids" => token_type_ids,
+                "attention_mask" => attention_mask,
+            }?)?;
 
-    drop(outputs);
+            let embeddings: Tensor<f32> = outputs["last_hidden_state"].extract_tensor()?;
 
-    state.put::<Session>(session);
-    state.put::<Tokenizer>(tokenizer);
+            let embeddings_view = embeddings.view();
+            let mean_pool = embeddings_view.mean_axis(Axis(1)).unwrap();
+            let (normalized, _) = normalize(
+                mean_pool.into_dimensionality::<Ix2>().unwrap(),
+                NormalizeAxis::Row,
+            );
 
-    Ok(slice)
+            let result = normalized.view().to_slice().unwrap().to_vec();
+            req.result_tx.send(result)?;
+        }
+        Ok(())
+    });
+
+    Ok(())
+}
+
+async fn run_gte(state: Rc<RefCell<OpState>>, prompt: String) -> Result<Vec<f32>, Error> {
+    let req_tx;
+    {
+        let op_state = state.borrow();
+        let maybe_req_tx = op_state.try_borrow::<mpsc::UnboundedSender<GteModelRequest>>();
+        if maybe_req_tx.is_none() {
+            bail!("Run init model first")
+        }
+        req_tx = maybe_req_tx.unwrap().clone();
+    }
+
+    let (result_tx, mut result_rx) = mpsc::unbounded_channel::<Vec<f32>>();
+
+    req_tx.send(GteModelRequest {
+        prompt,
+        result_tx: result_tx.clone(),
+    })?;
+
+    let result = result_rx.recv().await;
+    Ok(result.unwrap())
 }
 
 #[op2]
 #[serde]
-pub fn op_sb_ai_run_model(
-    state: &mut OpState,
+pub fn op_sb_ai_init_model(state: &mut OpState, #[string] name: String) -> Result<(), AnyError> {
+    if name == "gte-small" {
+        init_gte(state)
+    } else {
+        bail!("model not supported")
+    }
+}
+
+#[op2(async)]
+#[serde]
+pub async fn op_sb_ai_run_model(
+    state: Rc<RefCell<OpState>>,
     #[string] name: String,
     #[string] prompt: String,
 ) -> Result<Vec<f32>, AnyError> {
-    if name == "gte" {
-        run_gte(state, prompt)
+    if name == "gte-small" {
+        run_gte(state, prompt).await
     } else {
         bail!("model not supported")
     }
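The heart of the refactor is visible in lib.rs above: instead of taking the ort `Session` and `Tokenizer` out of `OpState` on every call, `init_gte` now spawns a long-lived tokio task that owns them and serves prompts over an unbounded mpsc channel, with each `GteModelRequest` carrying its own reply sender. Below is a minimal, self-contained sketch of that request/reply pattern; the `EchoRequest` type and the `format!` stand-in for inference are illustrative only, not code from this crate, and it assumes tokio with the `rt-multi-thread` and `macros` features enabled.

use tokio::sync::mpsc;

// One request: a prompt plus a dedicated channel to send the answer back on.
struct EchoRequest {
    prompt: String,
    result_tx: mpsc::UnboundedSender<String>,
}

// Spawn a worker task that owns the (stand-in) model and loops over requests,
// mirroring the task spawned inside init_gte.
fn spawn_worker() -> mpsc::UnboundedSender<EchoRequest> {
    let (req_tx, mut req_rx) = mpsc::unbounded_channel::<EchoRequest>();
    tokio::spawn(async move {
        while let Some(req) = req_rx.recv().await {
            // In lib.rs this is where the ort Session runs and the output is
            // mean-pooled and normalized; here we just echo the prompt.
            let result = format!("embedding for: {}", req.prompt);
            let _ = req.result_tx.send(result);
        }
    });
    req_tx
}

#[tokio::main]
async fn main() {
    let req_tx = spawn_worker();
    // Mirrors run_gte: create a reply channel, send the request, await the answer.
    let (result_tx, mut result_rx) = mpsc::unbounded_channel::<String>();
    req_tx
        .send(EchoRequest { prompt: "hello".into(), result_tx })
        .unwrap();
    println!("{}", result_rx.recv().await.unwrap());
}

This split is also why the ops change shape: `op_sb_ai_init_model` stays synchronous (it only spawns the worker and stores the sender in `OpState`), while `op_sb_ai_run_model` becomes `#[op2(async)]` so the isolate can await the worker's reply instead of blocking on inference.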
3 changes: 1 addition & 2 deletions crates/sb_core/js/bootstrap.js
@@ -20,7 +20,7 @@ import * as response from 'ext:deno_fetch/23_response.js';
 import * as request from 'ext:deno_fetch/23_request.js';
 import * as globalInterfaces from 'ext:deno_web/04_global_interfaces.js';
 import { SUPABASE_ENV } from 'ext:sb_env/env.js';
-import { SupabaseAI } from 'ext:sb_ai/ai.js';
+import ai from 'ext:sb_ai/ai.js';
 import { registerErrors } from 'ext:sb_core_main_js/js/errors.js';
 import {
 	formatException,
@@ -326,7 +326,6 @@ globalThis.bootstrapSBEdge = (
 	);
 	setLanguage('en');
 
-	const ai = new SupabaseAI();
 	Object.defineProperty(globalThis, 'Supabase_UNSTABLE', {
 		get() {
 			return {
4 changes: 3 additions & 1 deletion examples/gte-small-ort/index.ts
@@ -1,7 +1,9 @@
+const model = new Supabase_UNSTABLE.ai.Session('gte-small');
+
 Deno.serve(async (req: Request) => {
 	const params = new URL(req.url).searchParams;
 	const input = params.get('text');
-	const output = Supabase_UNSTABLE.ai.runModel('gte', input);
+	const output = await model.run(input);
 	return new Response(
 		JSON.stringify(
 			output,
