From 6258320f9a363600a37585fb1a38b9d54554a8c0 Mon Sep 17 00:00:00 2001 From: Lia Lenckowski Date: Mon, 25 Nov 2024 19:07:36 +0100 Subject: [PATCH 1/6] for image embeddings, turn some panics into returned errors --- src/image_embedding/impl.rs | 28 ++++++++++++++++++++-------- 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/src/image_embedding/impl.rs b/src/image_embedding/impl.rs index a01c3a1..5253b7d 100644 --- a/src/image_embedding/impl.rs +++ b/src/image_embedding/impl.rs @@ -16,7 +16,7 @@ use crate::{ common::normalize, models::image_embedding::models_list, Embedding, ImageEmbeddingModel, ModelInfo, }; -use anyhow::anyhow; +use anyhow::{anyhow, Context}; #[cfg(feature = "online")] use super::ImageInitOptions; @@ -52,13 +52,13 @@ impl ImageEmbedding { let preprocessor_file = model_repo .get("preprocessor_config.json") - .unwrap_or_else(|_| panic!("Failed to retrieve preprocessor_config.json")); + .context("Failed to retrieve preprocessor_config.json")?; let preprocessor = Compose::from_file(preprocessor_file)?; let model_file_name = ImageEmbedding::get_model_info(&model_name).model_file; let model_file_reference = model_repo .get(&model_file_name) - .unwrap_or_else(|_| panic!("Failed to retrieve {} ", model_file_name)); + .context(format!("Failed to retrieve {}", model_file_name))?; let session = Session::builder()? .with_execution_providers(execution_providers)? @@ -141,7 +141,7 @@ impl ImageEmbedding { // Determine the batch size, default if not specified let batch_size = batch_size.unwrap_or(DEFAULT_BATCH_SIZE); - let output = images + images .par_chunks(batch_size) .map(|batch| { // Encode the texts in the batch @@ -189,9 +189,21 @@ impl ImageEmbedding { Ok(embeddings) }) - .flat_map(|result: Result>, anyhow::Error>| result.unwrap()) - .collect(); - - Ok(output) + .try_fold( + || vec![], + |mut a, result| { + result.map(|mut es| { + a.append(&mut es); + a + }) + }, + ) + .try_reduce( + || vec![], + |mut a, mut b| { + a.append(&mut b); + Ok(a) + }, + ) } } From ebfa23eacac0747713cae1e4e003665b2a4697a6 Mon Sep 17 00:00:00 2001 From: Lia Lenckowski Date: Mon, 25 Nov 2024 21:56:57 +0100 Subject: [PATCH 2/6] turn more panics into recoverable errors --- src/image_embedding/impl.rs | 3 +- src/reranking/impl.rs | 17 ++-- src/sparse_text_embedding/impl.rs | 7 +- src/text_embedding/impl.rs | 144 +++++++++++++++--------------- 4 files changed, 84 insertions(+), 87 deletions(-) diff --git a/src/image_embedding/impl.rs b/src/image_embedding/impl.rs index 5253b7d..529ebd7 100644 --- a/src/image_embedding/impl.rs +++ b/src/image_embedding/impl.rs @@ -111,8 +111,7 @@ impl ImageEmbedding { let cache = Cache::new(cache_dir); let api = ApiBuilder::from_cache(cache) .with_progress(show_download_progress) - .build() - .unwrap(); + .build()?; let repo = api.model(model.to_string()); Ok(repo) diff --git a/src/reranking/impl.rs b/src/reranking/impl.rs index 853bfc6..8599024 100644 --- a/src/reranking/impl.rs +++ b/src/reranking/impl.rs @@ -1,4 +1,4 @@ -use anyhow::Result; +use anyhow::{Context, Result}; use ort::{ session::{builder::GraphOptimizationLevel, Session}, value::Value, @@ -70,15 +70,16 @@ impl TextRerank { let model_repo = api.model(model_name.to_string()); let model_file_name = TextRerank::get_model_info(&model_name).model_file; - let model_file_reference = model_repo - .get(&model_file_name) - .unwrap_or_else(|_| panic!("Failed to retrieve model file: {}", model_file_name)); + let model_file_reference = model_repo.get(&model_file_name).context(format!( + "Failed to retrieve model file: {}", + model_file_name + ))?; let additional_files = TextRerank::get_model_info(&model_name).additional_files; for additional_file in additional_files { - let _additional_file_reference = - model_repo.get(&additional_file).unwrap_or_else(|_| { - panic!("Failed to retrieve additional file: {}", additional_file) - }); + let _additional_file_reference = model_repo.get(&additional_file).context(format!( + "Failed to retrieve additional file: {}", + additional_file + ))?; } let session = Session::builder()? diff --git a/src/sparse_text_embedding/impl.rs b/src/sparse_text_embedding/impl.rs index 444ade6..e9c943b 100644 --- a/src/sparse_text_embedding/impl.rs +++ b/src/sparse_text_embedding/impl.rs @@ -4,7 +4,7 @@ use crate::{ models::sparse::{models_list, SparseModel}, ModelInfo, SparseEmbedding, }; -use anyhow::Result; +use anyhow::{Context, Result}; #[cfg(feature = "online")] use hf_hub::{ api::sync::{ApiBuilder, ApiRepo}, @@ -55,7 +55,7 @@ impl SparseTextEmbedding { let model_file_name = SparseTextEmbedding::get_model_info(&model_name).model_file; let model_file_reference = model_repo .get(&model_file_name) - .unwrap_or_else(|_| panic!("Failed to retrieve {} ", model_file_name)); + .context(format!("Failed to retrieve {} ", model_file_name))?; let session = Session::builder()? .with_execution_providers(execution_providers)? @@ -91,8 +91,7 @@ impl SparseTextEmbedding { let cache = Cache::new(cache_dir); let api = ApiBuilder::from_cache(cache) .with_progress(show_download_progress) - .build() - .unwrap(); + .build()?; let repo = api.model(model.to_string()); Ok(repo) diff --git a/src/text_embedding/impl.rs b/src/text_embedding/impl.rs index 7b15804..fa510a4 100644 --- a/src/text_embedding/impl.rs +++ b/src/text_embedding/impl.rs @@ -8,6 +8,7 @@ use crate::{ pooling::Pooling, Embedding, EmbeddingModel, EmbeddingOutput, ModelInfo, QuantizationMode, SingleBatchOutput, }; +use anyhow::{Context, Result}; #[cfg(feature = "online")] use hf_hub::{ api::sync::{ApiBuilder, ApiRepo}, @@ -40,7 +41,7 @@ impl TextEmbedding { /// /// Uses the total number of CPUs available as the number of intra-threads #[cfg(feature = "online")] - pub fn try_new(options: InitOptions) -> anyhow::Result { + pub fn try_new(options: InitOptions) -> Result { let InitOptions { model_name, execution_providers, @@ -61,7 +62,7 @@ impl TextEmbedding { let model_file_name = &model_info.model_file; let model_file_reference = model_repo .get(model_file_name) - .unwrap_or_else(|_| panic!("Failed to retrieve {} ", model_file_name)); + .context(format!("Failed to retrieve {}", model_file_name))?; // TODO: If more models need .onnx_data, implement a better way to handle this // Probably by adding `additional_files` field in the `ModelInfo` struct @@ -95,7 +96,7 @@ impl TextEmbedding { pub fn try_new_from_user_defined( model: UserDefinedEmbeddingModel, options: InitOptionsUserDefined, - ) -> anyhow::Result { + ) -> Result { let InitOptionsUserDefined { execution_providers, max_length, @@ -147,8 +148,7 @@ impl TextEmbedding { let cache = Cache::new(cache_dir); let api = ApiBuilder::from_cache(cache) .with_progress(show_download_progress) - .build() - .unwrap(); + .build()?; let repo = api.model(model.to_string()); Ok(repo) @@ -160,7 +160,7 @@ impl TextEmbedding { } /// Get ModelInfo from EmbeddingModel - pub fn get_model_info(model: &EmbeddingModel) -> anyhow::Result<&ModelInfo> { + pub fn get_model_info(model: &EmbeddingModel) -> Result<&ModelInfo> { get_model_info(model).ok_or_else(|| { anyhow::Error::msg(format!( "Model {model:?} not found. Please check if the model is supported \ @@ -195,7 +195,7 @@ impl TextEmbedding { &'e self, texts: Vec, batch_size: Option, - ) -> anyhow::Result> + ) -> Result> where 'e: 'r, 'e: 's, @@ -223,72 +223,70 @@ impl TextEmbedding { _ => Ok(batch_size.unwrap_or(DEFAULT_BATCH_SIZE)), }?; - let batches = - anyhow::Result::>::from_par_iter(texts.par_chunks(batch_size).map(|batch| { - // Encode the texts in the batch - let inputs = batch.iter().map(|text| text.as_ref()).collect(); - let encodings = self.tokenizer.encode_batch(inputs, true).map_err(|e| { - anyhow::Error::msg(e.to_string()).context("Failed to encode the batch.") - })?; - - // Extract the encoding length and batch size - let encoding_length = encodings[0].len(); - let batch_size = batch.len(); - - let max_size = encoding_length * batch_size; - - // Preallocate arrays with the maximum size - let mut ids_array = Vec::with_capacity(max_size); - let mut mask_array = Vec::with_capacity(max_size); - let mut typeids_array = Vec::with_capacity(max_size); - - // Not using par_iter because the closure needs to be FnMut - encodings.iter().for_each(|encoding| { - let ids = encoding.get_ids(); - let mask = encoding.get_attention_mask(); - let typeids = encoding.get_type_ids(); - - // Extend the preallocated arrays with the current encoding - // Requires the closure to be FnMut - ids_array.extend(ids.iter().map(|x| *x as i64)); - mask_array.extend(mask.iter().map(|x| *x as i64)); - typeids_array.extend(typeids.iter().map(|x| *x as i64)); - }); - - // Create CowArrays from vectors - let inputs_ids_array = - Array::from_shape_vec((batch_size, encoding_length), ids_array)?; - - let attention_mask_array = - Array::from_shape_vec((batch_size, encoding_length), mask_array)?; - - let token_type_ids_array = - Array::from_shape_vec((batch_size, encoding_length), typeids_array)?; - - let mut session_inputs = ort::inputs![ - "input_ids" => Value::from_array(inputs_ids_array)?, - "attention_mask" => Value::from_array(attention_mask_array.view())?, - ]?; - - if self.need_token_type_ids { - session_inputs.push(( - "token_type_ids".into(), - Value::from_array(token_type_ids_array)?.into(), - )); - } + let batches = Result::>::from_par_iter(texts.par_chunks(batch_size).map(|batch| { + // Encode the texts in the batch + let inputs = batch.iter().map(|text| text.as_ref()).collect(); + let encodings = self.tokenizer.encode_batch(inputs, true).map_err(|e| { + anyhow::Error::msg(e.to_string()).context("Failed to encode the batch.") + })?; + + // Extract the encoding length and batch size + let encoding_length = encodings[0].len(); + let batch_size = batch.len(); + + let max_size = encoding_length * batch_size; + + // Preallocate arrays with the maximum size + let mut ids_array = Vec::with_capacity(max_size); + let mut mask_array = Vec::with_capacity(max_size); + let mut typeids_array = Vec::with_capacity(max_size); + + // Not using par_iter because the closure needs to be FnMut + encodings.iter().for_each(|encoding| { + let ids = encoding.get_ids(); + let mask = encoding.get_attention_mask(); + let typeids = encoding.get_type_ids(); + + // Extend the preallocated arrays with the current encoding + // Requires the closure to be FnMut + ids_array.extend(ids.iter().map(|x| *x as i64)); + mask_array.extend(mask.iter().map(|x| *x as i64)); + typeids_array.extend(typeids.iter().map(|x| *x as i64)); + }); + + // Create CowArrays from vectors + let inputs_ids_array = Array::from_shape_vec((batch_size, encoding_length), ids_array)?; + + let attention_mask_array = + Array::from_shape_vec((batch_size, encoding_length), mask_array)?; + + let token_type_ids_array = + Array::from_shape_vec((batch_size, encoding_length), typeids_array)?; + + let mut session_inputs = ort::inputs![ + "input_ids" => Value::from_array(inputs_ids_array)?, + "attention_mask" => Value::from_array(attention_mask_array.view())?, + ]?; + + if self.need_token_type_ids { + session_inputs.push(( + "token_type_ids".into(), + Value::from_array(token_type_ids_array)?.into(), + )); + } - Ok( - // Package all the data required for post-processing (e.g. pooling) - // into a SingleBatchOutput struct. - SingleBatchOutput { - session_outputs: self - .session - .run(session_inputs) - .map_err(anyhow::Error::new)?, - attention_mask_array, - }, - ) - }))?; + Ok( + // Package all the data required for post-processing (e.g. pooling) + // into a SingleBatchOutput struct. + SingleBatchOutput { + session_outputs: self + .session + .run(session_inputs) + .map_err(anyhow::Error::new)?, + attention_mask_array, + }, + ) + }))?; Ok(EmbeddingOutput::new(batches)) } @@ -308,7 +306,7 @@ impl TextEmbedding { &self, texts: Vec, batch_size: Option, - ) -> anyhow::Result> { + ) -> Result> { let batches = self.transform(texts, batch_size)?; batches.export_with_transformer(output::transformer_with_precedence( From 96fdbf3bb4db83e8ea6b8606556fe55581ec3244 Mon Sep 17 00:00:00 2001 From: llenck Date: Tue, 26 Nov 2024 08:31:17 +0100 Subject: [PATCH 3/6] Simplify collection of image embeddings Co-authored-by: Anush --- src/image_embedding/impl.rs | 20 ++++---------------- 1 file changed, 4 insertions(+), 16 deletions(-) diff --git a/src/image_embedding/impl.rs b/src/image_embedding/impl.rs index 529ebd7..788cfe4 100644 --- a/src/image_embedding/impl.rs +++ b/src/image_embedding/impl.rs @@ -188,21 +188,9 @@ impl ImageEmbedding { Ok(embeddings) }) - .try_fold( - || vec![], - |mut a, result| { - result.map(|mut es| { - a.append(&mut es); - a - }) - }, - ) - .try_reduce( - || vec![], - |mut a, mut b| { - a.append(&mut b); - Ok(a) - }, - ) + .collect::, Error>>()? + .into_iter() + .flatten() + .collect(); } } From 1dc2115eaaaf74565d6923b7694f2aa977205cbc Mon Sep 17 00:00:00 2001 From: Lia Lenckowski Date: Tue, 26 Nov 2024 09:08:07 +0100 Subject: [PATCH 4/6] Pass errors in sparse text embedding and reranking batches onto the user instead of panicking --- src/image_embedding/impl.rs | 6 ++++-- src/reranking/impl.rs | 4 +++- src/sparse_text_embedding/impl.rs | 4 +++- 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/src/image_embedding/impl.rs b/src/image_embedding/impl.rs index 788cfe4..76bd04d 100644 --- a/src/image_embedding/impl.rs +++ b/src/image_embedding/impl.rs @@ -140,7 +140,7 @@ impl ImageEmbedding { // Determine the batch size, default if not specified let batch_size = batch_size.unwrap_or(DEFAULT_BATCH_SIZE); - images + let output = images .par_chunks(batch_size) .map(|batch| { // Encode the texts in the batch @@ -188,9 +188,11 @@ impl ImageEmbedding { Ok(embeddings) }) - .collect::, Error>>()? + .collect::>>()? .into_iter() .flatten() .collect(); + + Ok(output) } } diff --git a/src/reranking/impl.rs b/src/reranking/impl.rs index 8599024..76aba79 100644 --- a/src/reranking/impl.rs +++ b/src/reranking/impl.rs @@ -197,7 +197,9 @@ impl TextRerank { Ok(scores) }) - .flat_map(|result: Result, anyhow::Error>| result.unwrap()) + .collect::>>()? + .into_iter() + .flatten() .collect(); // Return top_n_result of type Vec ordered by score in descending order, don't use binary heap diff --git a/src/sparse_text_embedding/impl.rs b/src/sparse_text_embedding/impl.rs index e9c943b..a6b29a0 100644 --- a/src/sparse_text_embedding/impl.rs +++ b/src/sparse_text_embedding/impl.rs @@ -188,7 +188,9 @@ impl SparseTextEmbedding { Ok(embeddings) }) - .flat_map(|result: Result, anyhow::Error>| result.unwrap()) + .collect::>>()? + .into_iter() + .flatten() .collect(); Ok(output) From 2d78e01efacdd0e954dfbf2ae5f917e2a041564b Mon Sep 17 00:00:00 2001 From: Lia Lenckowski Date: Tue, 26 Nov 2024 09:38:52 +0100 Subject: [PATCH 5/6] fix warnings: only import anyhow::Context if online feature enabled --- src/image_embedding/impl.rs | 4 +++- src/reranking/impl.rs | 4 +++- src/sparse_text_embedding/impl.rs | 4 +++- src/text_embedding/impl.rs | 4 +++- 4 files changed, 12 insertions(+), 4 deletions(-) diff --git a/src/image_embedding/impl.rs b/src/image_embedding/impl.rs index 76bd04d..d614971 100644 --- a/src/image_embedding/impl.rs +++ b/src/image_embedding/impl.rs @@ -16,7 +16,9 @@ use crate::{ common::normalize, models::image_embedding::models_list, Embedding, ImageEmbeddingModel, ModelInfo, }; -use anyhow::{anyhow, Context}; +#[cfg(feature = "online")] +use anyhow::Context; +use anyhow::anyhow; #[cfg(feature = "online")] use super::ImageInitOptions; diff --git a/src/reranking/impl.rs b/src/reranking/impl.rs index 76aba79..cc458fb 100644 --- a/src/reranking/impl.rs +++ b/src/reranking/impl.rs @@ -1,4 +1,6 @@ -use anyhow::{Context, Result}; +#[cfg(feature = "online")] +use anyhow::Context; +use anyhow::Result; use ort::{ session::{builder::GraphOptimizationLevel, Session}, value::Value, diff --git a/src/sparse_text_embedding/impl.rs b/src/sparse_text_embedding/impl.rs index a6b29a0..2bd4256 100644 --- a/src/sparse_text_embedding/impl.rs +++ b/src/sparse_text_embedding/impl.rs @@ -4,7 +4,9 @@ use crate::{ models::sparse::{models_list, SparseModel}, ModelInfo, SparseEmbedding, }; -use anyhow::{Context, Result}; +#[cfg(feature = "online")] +use anyhow::Context; +use anyhow::Result; #[cfg(feature = "online")] use hf_hub::{ api::sync::{ApiBuilder, ApiRepo}, diff --git a/src/text_embedding/impl.rs b/src/text_embedding/impl.rs index fa510a4..213a86c 100644 --- a/src/text_embedding/impl.rs +++ b/src/text_embedding/impl.rs @@ -8,7 +8,9 @@ use crate::{ pooling::Pooling, Embedding, EmbeddingModel, EmbeddingOutput, ModelInfo, QuantizationMode, SingleBatchOutput, }; -use anyhow::{Context, Result}; +#[cfg(feature = "online")] +use anyhow::Context; +use anyhow::Result; #[cfg(feature = "online")] use hf_hub::{ api::sync::{ApiBuilder, ApiRepo}, From d51dea8bf599ed5cd044968ee84b68fb12e8671b Mon Sep 17 00:00:00 2001 From: Lia Lenckowski Date: Tue, 26 Nov 2024 10:07:52 +0100 Subject: [PATCH 6/6] cargo fmt --- src/image_embedding/impl.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/image_embedding/impl.rs b/src/image_embedding/impl.rs index d614971..104f85e 100644 --- a/src/image_embedding/impl.rs +++ b/src/image_embedding/impl.rs @@ -16,9 +16,9 @@ use crate::{ common::normalize, models::image_embedding::models_list, Embedding, ImageEmbeddingModel, ModelInfo, }; +use anyhow::anyhow; #[cfg(feature = "online")] use anyhow::Context; -use anyhow::anyhow; #[cfg(feature = "online")] use super::ImageInitOptions;