Skip to content

Commit

Permalink
Optimize ReLU max pooling and fix SQL query for non-integer primary keys in embedding jobs
Browse files Browse the repository at this point in the history
  • Loading branch information
var77 committed Jul 6, 2024
1 parent 9c258af commit 2e89f8d
Show file tree
Hide file tree
Showing 7 changed files with 125 additions and 9 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/publish-cli-docker.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ on:
type: string
description: "CLI version"
required: true
default: "0.3.8"
default: "0.3.9"
IMAGE_NAME:
type: string
description: "Container image name to tag"
Expand Down
2 changes: 1 addition & 1 deletion lantern_cli/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "lantern_cli"
version = "0.3.8"
version = "0.3.9"
edition = "2021"

[[bin]]
Expand Down
4 changes: 2 additions & 2 deletions lantern_cli/src/daemon/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::external_index::cli::CreateIndexArgs;
use crate::index_autotune::cli::IndexAutotuneArgs;
use crate::logger::Logger;
use crate::types::{AnyhowVoidResult, ProgressCbFn};
use crate::utils::get_common_embedding_ignore_filters;
use crate::utils::{get_common_embedding_ignore_filters, quote_literal};
use itertools::Itertools;
use std::collections::HashMap;
use std::sync::Arc;
Expand Down Expand Up @@ -127,7 +127,7 @@ impl EmbeddingJob {
}

pub fn set_id_filter(&mut self, row_ids: &Vec<String>) {
let row_ctids_str = row_ids.iter().join(",");
let row_ctids_str = row_ids.iter().map(|s| quote_literal(s)).join(",");
self.set_filter(&format!(
"id IN ({row_ctids_str}) AND {common_filter}",
common_filter = get_common_embedding_ignore_filters(&self.column)
Expand Down
5 changes: 1 addition & 4 deletions lantern_cli/src/embeddings/core/ort_runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,7 @@ impl PoolingStrategy {
output_dims: usize,
) -> Vec<Vec<f32>> {
// Apply ReLU: max(0, x)
let relu_embeddings = embeddings.mapv(|x| x.max(0.0));

// Apply log(1 + x)
let relu_log_embeddings = relu_embeddings.mapv(|x| (1.0 + x).ln());
let relu_log_embeddings = embeddings.mapv(|x| (1.0 + x.max(0.0)).ln());

// Expand attention mask to match embeddings dimensions
let attention_mask_shape = attention_mask.shape();
Expand Down
2 changes: 1 addition & 1 deletion lantern_cli/src/embeddings/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,7 @@ pub fn get_default_batch_size(model: &str) -> usize {
"thenlper/gte-base" => 1000,
"thenlper/gte-large" => 800,
"microsoft/all-MiniLM-L12-v2" => 1000,
"naver/splade-v3" => 100,
"naver/splade-v3" => 30,
"microsoft/all-mpnet-base-v2" => 400,
"transformers/multi-qa-mpnet-base-dot-v1" => 300,
"openai/text-embedding-ada-002" => 500,
Expand Down
5 changes: 5 additions & 0 deletions lantern_cli/src/utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@ pub fn quote_ident(str: &str) -> String {
format!("\"{}\"", str.replace("\"", "\"\""))
}

/// Quotes `str` so it can be embedded in a SQL statement as a string literal.
///
/// Every single quote is doubled and the result is wrapped in single quotes.
/// If the input contains a backslash, each backslash is doubled as well and the
/// literal is emitted in escape-string form (`E'...'`), the same shape
/// PostgreSQL's own `quote_literal` produces. This keeps the value intact (and
/// the literal un-escapable) even when the server runs with
/// `standard_conforming_strings = off`, where a bare `\'` would otherwise be
/// read as an escaped quote.
///
/// Inputs without backslashes produce exactly `'<input with '' doubled>'`.
pub fn quote_literal(str: &str) -> String {
    let needs_escape_form = str.contains('\\');

    // +3 covers the two surrounding quotes and the optional `E` prefix.
    let mut quoted = String::with_capacity(str.len() + 3);
    if needs_escape_form {
        quoted.push('E');
    }
    quoted.push('\'');
    for ch in str.chars() {
        match ch {
            '\'' => quoted.push_str("''"),
            // Only reachable in escape-string form, where `\\` denotes one backslash.
            '\\' => quoted.push_str("\\\\"),
            _ => quoted.push(ch),
        }
    }
    quoted.push('\'');
    quoted
}

pub fn get_full_table_name(schema: &str, table: &str) -> String {
let schema = quote_ident(schema);
let table = quote_ident(table);
Expand Down
114 changes: 114 additions & 0 deletions lantern_cli/tests/daemon_embeddings_test_with_db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -860,3 +860,117 @@ async fn test_daemon_embedding_init_job_streaming_large() {

cancel_token.cancel();
}

#[tokio::test]
async fn test_daemon_embedding_job_text_pk() {
    // `setup_test` hands back a connection URI (given to the daemon below) and a
    // client used here for seeding and polling.
    let (new_connection_uri, mut new_db_client) =
        setup_test("test_daemon_embedding_job_text_pk").await.unwrap();

    // Seed a table whose primary key is TEXT, five rows with no embedding yet,
    // and a single embedding job (id 15) targeting that table.
    let seed_sql = format!(
        r#"
CREATE TABLE {CLIENT_TABLE_NAME_2} (id TEXT PRIMARY KEY, title TEXT, title_embedding REAL[]);
INSERT INTO {CLIENT_TABLE_NAME_2} (id, title)
VALUES ('id1', 'Test1'),
('id2','Test2'),
('id3','Test3'),
('id4','Test4'),
('id5','Test5');
INSERT INTO _lantern_extras_internal.embedding_generation_jobs ("id", "table", src_column, dst_column, embedding_model)
VALUES (15, '{CLIENT_TABLE_NAME_2}', 'title', 'title_embedding', 'BAAI/bge-small-en');
"#
    );
    new_db_client.batch_execute(&seed_sql).await.unwrap();

    // One token stays with the test for shutdown; its clone goes to the daemon.
    let cancel_token = CancellationToken::new();
    let daemon_cancel = cancel_token.clone();

    // Only the embeddings worker is enabled for this run.
    let daemon_args = DaemonArgs {
        label: None,
        master_db: None,
        master_db_schema: String::new(),
        embeddings: true,
        autotune: false,
        external_index: false,
        databases_table: String::new(),
        schema: "_lantern_extras_internal".to_owned(),
        target_db: Some(vec![new_connection_uri]),
        log_level: LogLevel::Debug,
    };

    tokio::spawn(async move {
        daemon::start(daemon_args, None, daemon_cancel)
            .await
            .unwrap();
    });

    // The job is considered done once all five rows carry an embedding.
    // NOTE(review): the trailing `30` is presumably a timeout in seconds — confirm
    // against `wait_for_completion`.
    let all_rows_embedded = format!(
        "SELECT COUNT(*)=5 FROM {CLIENT_TABLE_NAME_2} WHERE title_embedding IS NOT NULL"
    );
    wait_for_completion(&mut new_db_client, &all_rows_embedded, 30)
        .await
        .unwrap();

    cancel_token.cancel();
}

#[tokio::test]
async fn test_daemon_embedding_job_uuid_pk() {
    // `setup_test` hands back a connection URI (given to the daemon below) and a
    // client used here for seeding and polling.
    let (new_connection_uri, mut new_db_client) =
        setup_test("test_daemon_embedding_job_uuid_pk").await.unwrap();

    // Seed a table whose primary key is a UUID filled in by `gen_random_uuid()`,
    // five rows with no embedding yet, and a single embedding job (id 16).
    let seed_sql = format!(
        r#"
CREATE TABLE {CLIENT_TABLE_NAME_2} (id UUID PRIMARY KEY DEFAULT gen_random_uuid(), title TEXT, title_embedding REAL[]);
INSERT INTO {CLIENT_TABLE_NAME_2} (title)
VALUES ('Test1'),
('Test2'),
('Test3'),
('Test4'),
('Test5');
INSERT INTO _lantern_extras_internal.embedding_generation_jobs ("id", "table", src_column, dst_column, embedding_model)
VALUES (16, '{CLIENT_TABLE_NAME_2}', 'title', 'title_embedding', 'BAAI/bge-small-en');
"#
    );
    new_db_client.batch_execute(&seed_sql).await.unwrap();

    // One token stays with the test for shutdown; its clone goes to the daemon.
    let cancel_token = CancellationToken::new();
    let daemon_cancel = cancel_token.clone();

    // Only the embeddings worker is enabled for this run.
    let daemon_args = DaemonArgs {
        label: None,
        master_db: None,
        master_db_schema: String::new(),
        embeddings: true,
        autotune: false,
        external_index: false,
        databases_table: String::new(),
        schema: "_lantern_extras_internal".to_owned(),
        target_db: Some(vec![new_connection_uri]),
        log_level: LogLevel::Debug,
    };

    tokio::spawn(async move {
        daemon::start(daemon_args, None, daemon_cancel)
            .await
            .unwrap();
    });

    // The job is considered done once all five rows carry an embedding.
    // NOTE(review): the trailing `30` is presumably a timeout in seconds — confirm
    // against `wait_for_completion`.
    let all_rows_embedded = format!(
        "SELECT COUNT(*)=5 FROM {CLIENT_TABLE_NAME_2} WHERE title_embedding IS NOT NULL"
    );
    wait_for_completion(&mut new_db_client, &all_rows_embedded, 30)
        .await
        .unwrap();

    cancel_token.cancel();
}

0 comments on commit 2e89f8d

Please sign in to comment.