
Commit

add labels to embedding jobs, fix init_finished_at for empty tables
var77 committed Jun 7, 2024
1 parent 4de0b62 commit 468c240
Showing 11 changed files with 183 additions and 13 deletions.
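
Before the per-file diff, a minimal standalone sketch of the label-matching rule this commit introduces (not code from the repository; the function and example labels are illustrative). The daemon's new --label argument defaults to an empty string, unlabeled jobs are picked up by any daemon instance, and a labeled job is only picked up by a daemon started with the same label, as implemented in job_insert_processor below.

fn should_process(job_label: Option<&str>, daemon_label: &str) -> bool {
    match job_label {
        // labeled jobs are only processed by a daemon whose --label matches
        Some(label) => label == daemon_label,
        // unlabeled jobs are processed by every daemon instance
        None => true,
    }
}

fn main() {
    assert!(should_process(None, ""));            // unlabeled job, unlabeled daemon
    assert!(should_process(None, "gpu"));         // unlabeled job, labeled daemon
    assert!(should_process(Some("gpu"), "gpu"));  // matching labels
    assert!(!should_process(Some("gpu"), ""));    // labeled job, unlabeled daemon
    assert!(!should_process(Some("gpu"), "cpu")); // mismatched labels
}
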
2 changes: 1 addition & 1 deletion lantern_cli/Cargo.toml
@@ -1,6 +1,6 @@
[package]
name = "lantern_cli"
version = "0.3.0"
version = "0.3.1"
edition = "2021"

[[bin]]
1 change: 1 addition & 0 deletions lantern_cli/src/daemon/autotune_jobs.rs
@@ -268,6 +268,7 @@ pub async fn start(args: JobRunArgs, logger: Arc<Logger>, cancel_token: Cancella
Some(RESULT_TABLE_DEFINITION),
None,
None,
None,
&notification_channel,
logger.clone(),
)
4 changes: 4 additions & 0 deletions lantern_cli/src/daemon/cli.rs
@@ -63,6 +63,10 @@ pub struct DaemonArgs {
#[arg(short, long, default_value = "_lantern_internal")]
pub schema: String,

/// Label which will be matched against embedding job label
#[arg(long)]
pub label: Option<String>,

/// Log level
#[arg(long, value_enum, default_value_t = LogLevel::Info)] // arg_enum here
pub log_level: LogLevel,
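
For context, a minimal self-contained sketch of how the new optional --label flag parses, assuming the clap v4 derive API already used in this struct; the wrapper struct and invocation below are illustrative, not the daemon's full argument set. Omitting the flag yields None, which the embedding jobs runner maps to an empty daemon label.

use clap::Parser;

#[derive(Parser, Debug)]
struct Args {
    /// Label which will be matched against embedding job label
    #[arg(long)]
    label: Option<String>,
}

fn main() {
    // `daemon --label gpu` -> Some("gpu"); plain `daemon` -> None
    let args = Args::parse();
    println!("label = {:?}", args.label);
}
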
78 changes: 67 additions & 11 deletions lantern_cli/src/daemon/embedding_jobs.rs
@@ -5,7 +5,7 @@ use super::helpers::{
};
use super::types::{
ClientJobsMap, EmbeddingJob, JobBatchingHashMap, JobEvent, JobEventHandlersMap,
JobInsertNotification, JobRunArgs, JobUpdateNotification,
JobInsertNotification, JobLabelsMap, JobRunArgs, JobUpdateNotification,
};
use crate::daemon::helpers::anyhow_wrap_connection;
use crate::embeddings::cli::EmbeddingArgs;
@@ -29,6 +29,7 @@ pub const JOB_TABLE_DEFINITION: &'static str = r#"
"schema" text NOT NULL DEFAULT 'public',
"table" text NOT NULL,
"pk" text NOT NULL DEFAULT 'id',
"label" text NULL,
"runtime" text NOT NULL DEFAULT 'ort',
"runtime_params" jsonb,
"src_column" text NOT NULL,
@@ -169,6 +170,7 @@ async fn stream_job(
let job_clone = job.clone();
let jobs_map_clone = jobs_map.clone();
let client_jobs_map_clone = client_jobs_map.clone();
let jobs_table_name_clone = jobs_table_name.clone();

let task = tokio::spawn(async move {
logger.info(&format!("Start streaming job {}", job.id));
@@ -253,6 +255,17 @@ async fn stream_job(
let total_rows = total_rows as usize;

if total_rows == 0 {
if job.is_init {
transaction.execute(
&format!(
"UPDATE {jobs_table_name} SET init_finished_at=NOW(), updated_at=NOW(), init_progress=100 WHERE id=$1"
),
&[&job.id],
)
.await?;
transaction.commit().await?;
}

return Ok(());
}

@@ -334,7 +347,7 @@ async fn stream_job(
top_logger.error(&format!("Error while streaming job {job_id}: {e}"));
remove_job_handle(&jobs_map_clone, job_id).await?;
if job_clone.is_init {
main_client.execute(&format!("UPDATE {jobs_table_name} SET init_failed_at=NOW(), updated_at=NOW(), init_failure_reason=$1 WHERE id=$2 AND init_finished_at IS NULL"), &[&e.to_string(), &job_id]).await?;
main_client.execute(&format!("UPDATE {jobs_table_name_clone} SET init_failed_at=NOW(), updated_at=NOW(), init_failure_reason=$1 WHERE id=$2 AND init_finished_at IS NULL"), &[&e.to_string(), &job_id]).await?;
toggle_client_job(
client_jobs_map_clone.clone(),
job_clone.id,
@@ -556,10 +569,12 @@ async fn job_insert_processor(
db_uri: String,
schema: String,
table: String,
daemon_label: String,
data_path: &'static str,
jobs_map: Arc<JobEventHandlersMap>,
job_batching_hashmap: Arc<JobBatchingHashMap>,
client_jobs_map: Arc<ClientJobsMap>,
job_labels_map: Arc<JobLabelsMap>,
logger: Arc<Logger>,
) -> AnyhowVoidResult {
// This function will have 2 running tasks
@@ -579,8 +594,7 @@ async fn job_insert_processor(
// batch jobs for the rows. This optimizes embedding generation: if there are lots of
// inserts to the table within the 10 second window, all those rows will be batched.
let full_table_name = Arc::new(get_full_table_name(&schema, &table));
// TODO:: Select pk here
let job_query_sql = Arc::new(format!("SELECT id, src_column as \"column\", dst_column, \"table\", \"schema\", embedding_model as model, runtime, runtime_params::text, init_finished_at FROM {0}", &full_table_name));
let job_query_sql = Arc::new(format!("SELECT id, pk, label, src_column as \"column\", dst_column, \"table\", \"schema\", embedding_model as model, runtime, runtime_params::text, init_finished_at FROM {0}", &full_table_name));

let db_uri_r1 = db_uri.clone();
let full_table_name_r1 = full_table_name.clone();
@@ -589,6 +603,7 @@
let logger_r1 = logger.clone();
let lock_table_name = Arc::new(get_full_table_name(&schema, EMB_LOCK_TABLE_NAME));
let job_batching_hashmap_r1 = job_batching_hashmap.clone();
let job_labels_map_r1 = job_labels_map.clone();

let insert_processor_task = tokio::spawn(async move {
let (insert_client, connection) = tokio_postgres::connect(&db_uri_r1, NoTls).await?;
@@ -597,6 +612,14 @@
while let Some(notification) = notifications_rx.recv().await {
let id = notification.id;

// if the job's label does not match the label of the current daemon instance,
// do not process the row
if let Some(label) = job_labels_map.read().await.get(&id) {
if label != &daemon_label {
continue;
}
}

if let Some(row_id) = notification.row_id {
// Do this in a non-blocking way to not block collecting of updates while locking
let client_r1 = insert_client.clone();
@@ -645,6 +668,24 @@
.get::<&str, Option<SystemTime>>("init_finished_at")
.is_none();

let job = EmbeddingJob::new(row, data_path, &db_uri_r1);

if let Err(e) = &job {
logger_r1.error(&format!("Error while creating job {id}: {e}",));
continue;
}

let mut job = job.unwrap();

if let Some(label) = &job.label {
// insert label in cache
job_labels_map.write().await.insert(job.id, label.clone());

if label != &daemon_label {
continue;
}
}

if is_init {
// Only update init time if this is the first time job is being executed
let updated_count = insert_client.execute(&format!("UPDATE {0} SET init_started_at=NOW() WHERE init_started_at IS NULL AND id=$1", &full_table_name_r1), &[&id]).await?;
@@ -653,13 +694,6 @@
}
}

let job = EmbeddingJob::new(row, data_path, &db_uri_r1);

if let Err(e) = &job {
logger_r1.error(&format!("Error while creating job {id}: {e}",));
continue;
}
let mut job = job.unwrap();
job.set_is_init(is_init);
if let Some(filter) = notification.filter {
job.set_filter(&filter);
@@ -715,6 +749,17 @@
continue;
}
let mut job = job.unwrap();

// update label in cache
if let Some(label) = &job.label {
job_labels_map_r1
.write()
.await
.insert(job.id, label.clone());
} else {
job_labels_map_r1.write().await.remove(&job.id);
}

job.set_is_init(false);
let rows_len = row_ids.len();
job.set_id_filter(&row_ids);
@@ -903,6 +948,13 @@ pub async fn start(
mpsc::channel(1);
let table = args.table_name;

//TODO:: Remove migration on next release
let migration_sql = format!(
"ALTER TABLE {full_table_name} ADD COLUMN IF NOT EXISTS label TEXT NULL;",
full_table_name = get_full_table_name(&args.schema, &table)
);
// ======================================

startup_hook(
&mut main_db_client,
&table,
@@ -913,6 +965,7 @@
None,
Some(EMB_USAGE_TABLE_NAME),
Some(USAGE_TABLE_DEFINITION),
Some(&migration_sql),
&notification_channel,
logger.clone(),
)
@@ -929,6 +982,7 @@

let jobs_map: Arc<JobEventHandlersMap> = Arc::new(RwLock::new(HashMap::new()));
let client_jobs_map: Arc<ClientJobsMap> = Arc::new(RwLock::new(HashMap::new()));
let job_labels_map: Arc<JobLabelsMap> = Arc::new(RwLock::new(HashMap::new()));

let job_batching_hashmap: Arc<JobBatchingHashMap> = Arc::new(Mutex::new(HashMap::new()));

@@ -949,10 +1003,12 @@
main_db_uri.clone(),
schema.clone(),
table.clone(),
args.label.clone().unwrap_or(String::new()),
data_path,
jobs_map.clone(),
job_batching_hashmap.clone(),
client_jobs_map.clone(),
job_labels_map,
logger.clone(),
),
job_update_processor(
1 change: 1 addition & 0 deletions lantern_cli/src/daemon/external_index_jobs.rs
@@ -253,6 +253,7 @@ pub async fn start(args: JobRunArgs, logger: Arc<Logger>, cancel_token: Cancella
None,
None,
None,
None,
&notification_channel,
logger.clone(),
)
5 changes: 5 additions & 0 deletions lantern_cli/src/daemon/helpers.rs
@@ -124,6 +124,7 @@ pub async fn startup_hook(
results_table_def: Option<&str>,
usage_table_name: Option<&str>,
usage_table_def: Option<&str>,
migration: Option<&str>,
channel: &str,
logger: Arc<Logger>,
) -> AnyhowVoidResult {
@@ -156,6 +157,10 @@
)
.await?;

if let Some(migration_sql) = migration {
transaction.batch_execute(migration_sql).await?;
}

let insert_function_name = &get_full_table_name(schema, &format!("notify_insert_{table}"));
let update_function_name = &get_full_table_name(schema, &format!("notify_update_{table}"));
let insert_trigger_name = quote_ident(&format!("trigger_insert_{table}"));
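
A minimal sketch, outside this commit, of the pattern the new migration parameter enables: an optional SQL string executed with batch_execute inside the same startup transaction, so idempotent statements such as the ADD COLUMN IF NOT EXISTS migration in embedding_jobs.rs can be replayed safely on every start. The function name and connection handling here are illustrative, not lantern_cli's actual wiring.

use tokio_postgres::{Error, NoTls};

async fn run_startup_migration(db_uri: &str, migration: Option<&str>) -> Result<(), Error> {
    let (mut client, connection) = tokio_postgres::connect(db_uri, NoTls).await?;
    // drive the connection on a background task so queries can make progress
    tokio::spawn(async move {
        if let Err(e) = connection.await {
            eprintln!("connection error: {e}");
        }
    });

    let transaction = client.transaction().await?;
    if let Some(migration_sql) = migration {
        // batch_execute accepts multiple semicolon-separated statements
        transaction.batch_execute(migration_sql).await?;
    }
    transaction.commit().await?;
    Ok(())
}
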
3 changes: 3 additions & 0 deletions lantern_cli/src/daemon/mod.rs
@@ -96,6 +96,7 @@ async fn spawn_job(
JobType::Embeddings => {
embedding_jobs::start(
JobRunArgs {
label: args.label.clone(),
uri: target_db.uri.clone(),
schema: args.schema.clone(),
log_level: args.log_level.value(),
@@ -109,6 +110,7 @@
JobType::ExternalIndex => {
external_index_jobs::start(
JobRunArgs {
label: args.label.clone(),
uri: target_db.uri.clone(),
schema: args.schema.clone(),
log_level: args.log_level.value(),
@@ -122,6 +124,7 @@
JobType::Autotune => {
autotune_jobs::start(
JobRunArgs {
label: args.label.clone(),
uri: target_db.uri.clone(),
schema: args.schema.clone(),
log_level: args.log_level.value(),
6 changes: 5 additions & 1 deletion lantern_cli/src/daemon/types.rs
@@ -15,6 +15,7 @@ pub struct JobRunArgs {
pub schema: String,
pub log_level: crate::logger::LogLevel,
pub table_name: String,
pub label: Option<String>,
}

#[derive(Clone, Debug)]
@@ -56,6 +57,7 @@ pub struct EmbeddingJob {
pub column: String,
pub pk: String,
pub filter: Option<String>,
pub label: Option<String>,
pub out_column: String,
pub model: String,
pub runtime_params: String,
@@ -78,7 +80,8 @@ impl EmbeddingJob {

Ok(Self {
id: row.get::<&str, i32>("id"),
pk: "id".to_owned(), // TODO:: row.get::<&str, String>("pk"),
pk: row.get::<&str, String>("pk"),
label: row.get::<&str, Option<String>>("label"),
db_uri: db_uri.to_owned(),
schema: row.get::<&str, String>("schema"),
table: row.get::<&str, String>("table"),
@@ -224,6 +227,7 @@ pub type JobEventHandlersMap = RwLock<HashMap<i32, JobTaskEventTx>>;
pub type JobBatchingHashMap = Mutex<HashMap<i32, Vec<String>>>;
pub type ClientJobsMap = RwLock<HashMap<i32, UnboundedSender<ClientJobSignal>>>;
pub type DaemonJobHandlerMap = RwLock<HashMap<String, CancellationToken>>;
pub type JobLabelsMap = RwLock<HashMap<i32, String>>;

pub enum JobEvent {
Done,
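
A small standalone sketch of how the new JobLabelsMap alias serves as a label cache in job_insert_processor; the job id, labels, and printed message are illustrative.

use std::collections::HashMap;
use tokio::sync::RwLock;

type JobLabelsMap = RwLock<HashMap<i32, String>>;

#[tokio::main]
async fn main() {
    let labels: JobLabelsMap = RwLock::new(HashMap::new());
    let daemon_label = "gpu".to_owned();

    // cache the label the first time a labeled job is seen
    labels.write().await.insert(42, "cpu".to_owned());

    // later notifications consult the cache before doing any work
    if let Some(label) = labels.read().await.get(&42) {
        if label != &daemon_label {
            println!("job 42 belongs to another daemon instance, skipping");
        }
    }

    // when an update clears the job's label, evict it from the cache
    labels.write().await.remove(&42);
}
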
1 change: 1 addition & 0 deletions lantern_cli/tests/daemon_autotune_test_with_db.rs
@@ -37,6 +37,7 @@ async fn test_daemon_autotune_with_create_index() {
tokio::spawn(async {
daemon::start(
DaemonArgs {
label: None,
master_db: None,
master_db_schema: String::new(),
embeddings: false,