diff --git a/src/image_embedding/impl.rs b/src/image_embedding/impl.rs
index a01c3a1..104f85e 100644
--- a/src/image_embedding/impl.rs
+++ b/src/image_embedding/impl.rs
@@ -17,6 +17,8 @@ use crate::{
     ModelInfo,
 };
 use anyhow::anyhow;
+#[cfg(feature = "online")]
+use anyhow::Context;
 
 #[cfg(feature = "online")]
 use super::ImageInitOptions;
@@ -52,13 +54,13 @@ impl ImageEmbedding {
 
         let preprocessor_file = model_repo
             .get("preprocessor_config.json")
-            .unwrap_or_else(|_| panic!("Failed to retrieve preprocessor_config.json"));
+            .context("Failed to retrieve preprocessor_config.json")?;
         let preprocessor = Compose::from_file(preprocessor_file)?;
 
         let model_file_name = ImageEmbedding::get_model_info(&model_name).model_file;
         let model_file_reference = model_repo
             .get(&model_file_name)
-            .unwrap_or_else(|_| panic!("Failed to retrieve {} ", model_file_name));
+            .context(format!("Failed to retrieve {}", model_file_name))?;
 
         let session = Session::builder()?
             .with_execution_providers(execution_providers)?
@@ -111,8 +113,7 @@ impl ImageEmbedding {
         let cache = Cache::new(cache_dir);
         let api = ApiBuilder::from_cache(cache)
             .with_progress(show_download_progress)
-            .build()
-            .unwrap();
+            .build()?;
 
         let repo = api.model(model.to_string());
         Ok(repo)
@@ -189,7 +190,9 @@ impl ImageEmbedding {
 
                 Ok(embeddings)
             })
-            .flat_map(|result: Result<Vec<Vec<f32>>, anyhow::Error>| result.unwrap())
+            .collect::<Result<Vec<_>>>()?
+            .into_iter()
+            .flatten()
             .collect();
 
         Ok(output)
diff --git a/src/reranking/impl.rs b/src/reranking/impl.rs
index 853bfc6..cc458fb 100644
--- a/src/reranking/impl.rs
+++ b/src/reranking/impl.rs
@@ -1,3 +1,5 @@
+#[cfg(feature = "online")]
+use anyhow::Context;
 use anyhow::Result;
 use ort::{
     session::{builder::GraphOptimizationLevel, Session},
@@ -70,15 +72,16 @@ impl TextRerank {
         let model_repo = api.model(model_name.to_string());
 
         let model_file_name = TextRerank::get_model_info(&model_name).model_file;
-        let model_file_reference = model_repo
-            .get(&model_file_name)
-            .unwrap_or_else(|_| panic!("Failed to retrieve model file: {}", model_file_name));
+        let model_file_reference = model_repo.get(&model_file_name).context(format!(
+            "Failed to retrieve model file: {}",
+            model_file_name
+        ))?;
         let additional_files = TextRerank::get_model_info(&model_name).additional_files;
         for additional_file in additional_files {
-            let _additional_file_reference =
-                model_repo.get(&additional_file).unwrap_or_else(|_| {
-                    panic!("Failed to retrieve additional file: {}", additional_file)
-                });
+            let _additional_file_reference = model_repo.get(&additional_file).context(format!(
+                "Failed to retrieve additional file: {}",
+                additional_file
+            ))?;
         }
 
         let session = Session::builder()?
@@ -196,7 +199,9 @@ impl TextRerank {
 
                 Ok(scores)
             })
-            .flat_map(|result: Result<Vec<f32>, anyhow::Error>| result.unwrap())
+            .collect::<Result<Vec<_>>>()?
+            .into_iter()
+            .flatten()
             .collect();
 
         // Return top_n_result of type Vec<RerankResult> ordered by score in descending order, don't use binary heap
diff --git a/src/sparse_text_embedding/impl.rs b/src/sparse_text_embedding/impl.rs
index 444ade6..2bd4256 100644
--- a/src/sparse_text_embedding/impl.rs
+++ b/src/sparse_text_embedding/impl.rs
@@ -4,6 +4,8 @@ use crate::{
     models::sparse::{models_list, SparseModel},
     ModelInfo, SparseEmbedding,
 };
+#[cfg(feature = "online")]
+use anyhow::Context;
 use anyhow::Result;
 #[cfg(feature = "online")]
 use hf_hub::{
@@ -55,7 +57,7 @@ impl SparseTextEmbedding {
         let model_file_name = SparseTextEmbedding::get_model_info(&model_name).model_file;
         let model_file_reference = model_repo
             .get(&model_file_name)
-            .unwrap_or_else(|_| panic!("Failed to retrieve {} ", model_file_name));
+            .context(format!("Failed to retrieve {} ", model_file_name))?;
 
         let session = Session::builder()?
             .with_execution_providers(execution_providers)?
@@ -91,8 +93,7 @@ impl SparseTextEmbedding {
         let cache = Cache::new(cache_dir);
         let api = ApiBuilder::from_cache(cache)
             .with_progress(show_download_progress)
-            .build()
-            .unwrap();
+            .build()?;
 
         let repo = api.model(model.to_string());
         Ok(repo)
@@ -189,7 +190,9 @@ impl SparseTextEmbedding {
 
                 Ok(embeddings)
             })
-            .flat_map(|result: Result<Vec<SparseEmbedding>, anyhow::Error>| result.unwrap())
+            .collect::<Result<Vec<_>>>()?
+            .into_iter()
+            .flatten()
             .collect();
 
         Ok(output)
diff --git a/src/text_embedding/impl.rs b/src/text_embedding/impl.rs
index 7b15804..213a86c 100644
--- a/src/text_embedding/impl.rs
+++ b/src/text_embedding/impl.rs
@@ -9,6 +9,9 @@ use crate::{
     Embedding, EmbeddingModel, EmbeddingOutput, ModelInfo, QuantizationMode, SingleBatchOutput,
 };
 #[cfg(feature = "online")]
+use anyhow::Context;
+use anyhow::Result;
+#[cfg(feature = "online")]
 use hf_hub::{
     api::sync::{ApiBuilder, ApiRepo},
     Cache,
@@ -40,7 +43,7 @@ impl TextEmbedding {
     ///
     /// Uses the total number of CPUs available as the number of intra-threads
     #[cfg(feature = "online")]
-    pub fn try_new(options: InitOptions) -> anyhow::Result<Self> {
+    pub fn try_new(options: InitOptions) -> Result<Self> {
         let InitOptions {
             model_name,
             execution_providers,
@@ -61,7 +64,7 @@ impl TextEmbedding {
         let model_file_name = &model_info.model_file;
         let model_file_reference = model_repo
             .get(model_file_name)
-            .unwrap_or_else(|_| panic!("Failed to retrieve {} ", model_file_name));
+            .context(format!("Failed to retrieve {}", model_file_name))?;
 
         // TODO: If more models need .onnx_data, implement a better way to handle this
         // Probably by adding `additional_files` field in the `ModelInfo` struct
@@ -95,7 +98,7 @@ impl TextEmbedding {
     pub fn try_new_from_user_defined(
         model: UserDefinedEmbeddingModel,
         options: InitOptionsUserDefined,
-    ) -> anyhow::Result<Self> {
+    ) -> Result<Self> {
         let InitOptionsUserDefined {
             execution_providers,
             max_length,
@@ -147,8 +150,7 @@ impl TextEmbedding {
         let cache = Cache::new(cache_dir);
         let api = ApiBuilder::from_cache(cache)
             .with_progress(show_download_progress)
-            .build()
-            .unwrap();
+            .build()?;
 
         let repo = api.model(model.to_string());
         Ok(repo)
@@ -160,7 +162,7 @@ impl TextEmbedding {
     }
 
     /// Get ModelInfo from EmbeddingModel
-    pub fn get_model_info(model: &EmbeddingModel) -> anyhow::Result<&ModelInfo<EmbeddingModel>> {
+    pub fn get_model_info(model: &EmbeddingModel) -> Result<&ModelInfo<EmbeddingModel>> {
         get_model_info(model).ok_or_else(|| {
             anyhow::Error::msg(format!(
                 "Model {model:?} not found. Please check if the model is supported \
@@ -195,7 +197,7 @@ impl TextEmbedding {
         &'e self,
         texts: Vec<S>,
         batch_size: Option<usize>,
-    ) -> anyhow::Result<EmbeddingOutput<'r, 's>>
+    ) -> Result<EmbeddingOutput<'r, 's>>
     where
         'e: 'r,
         'e: 's,
@@ -223,72 +225,70 @@ impl TextEmbedding {
             _ => Ok(batch_size.unwrap_or(DEFAULT_BATCH_SIZE)),
         }?;
 
-        let batches =
-            anyhow::Result::<Vec<_>>::from_par_iter(texts.par_chunks(batch_size).map(|batch| {
-                // Encode the texts in the batch
-                let inputs = batch.iter().map(|text| text.as_ref()).collect();
-                let encodings = self.tokenizer.encode_batch(inputs, true).map_err(|e| {
-                    anyhow::Error::msg(e.to_string()).context("Failed to encode the batch.")
-                })?;
-
-                // Extract the encoding length and batch size
-                let encoding_length = encodings[0].len();
-                let batch_size = batch.len();
-
-                let max_size = encoding_length * batch_size;
-
-                // Preallocate arrays with the maximum size
-                let mut ids_array = Vec::with_capacity(max_size);
-                let mut mask_array = Vec::with_capacity(max_size);
-                let mut typeids_array = Vec::with_capacity(max_size);
-
-                // Not using par_iter because the closure needs to be FnMut
-                encodings.iter().for_each(|encoding| {
-                    let ids = encoding.get_ids();
-                    let mask = encoding.get_attention_mask();
-                    let typeids = encoding.get_type_ids();
-
-                    // Extend the preallocated arrays with the current encoding
-                    // Requires the closure to be FnMut
-                    ids_array.extend(ids.iter().map(|x| *x as i64));
-                    mask_array.extend(mask.iter().map(|x| *x as i64));
-                    typeids_array.extend(typeids.iter().map(|x| *x as i64));
-                });
-
-                // Create CowArrays from vectors
-                let inputs_ids_array =
-                    Array::from_shape_vec((batch_size, encoding_length), ids_array)?;
-
-                let attention_mask_array =
-                    Array::from_shape_vec((batch_size, encoding_length), mask_array)?;
-
-                let token_type_ids_array =
-                    Array::from_shape_vec((batch_size, encoding_length), typeids_array)?;
-
-                let mut session_inputs = ort::inputs![
-                    "input_ids" => Value::from_array(inputs_ids_array)?,
-                    "attention_mask" => Value::from_array(attention_mask_array.view())?,
-                ]?;
-
-                if self.need_token_type_ids {
-                    session_inputs.push((
-                        "token_type_ids".into(),
-                        Value::from_array(token_type_ids_array)?.into(),
-                    ));
-                }
+        let batches = Result::<Vec<_>>::from_par_iter(texts.par_chunks(batch_size).map(|batch| {
+            // Encode the texts in the batch
+            let inputs = batch.iter().map(|text| text.as_ref()).collect();
+            let encodings = self.tokenizer.encode_batch(inputs, true).map_err(|e| {
+                anyhow::Error::msg(e.to_string()).context("Failed to encode the batch.")
+            })?;
+
+            // Extract the encoding length and batch size
+            let encoding_length = encodings[0].len();
+            let batch_size = batch.len();
+
+            let max_size = encoding_length * batch_size;
+
+            // Preallocate arrays with the maximum size
+            let mut ids_array = Vec::with_capacity(max_size);
+            let mut mask_array = Vec::with_capacity(max_size);
+            let mut typeids_array = Vec::with_capacity(max_size);
+
+            // Not using par_iter because the closure needs to be FnMut
+            encodings.iter().for_each(|encoding| {
+                let ids = encoding.get_ids();
+                let mask = encoding.get_attention_mask();
+                let typeids = encoding.get_type_ids();
+
+                // Extend the preallocated arrays with the current encoding
+                // Requires the closure to be FnMut
+                ids_array.extend(ids.iter().map(|x| *x as i64));
+                mask_array.extend(mask.iter().map(|x| *x as i64));
+                typeids_array.extend(typeids.iter().map(|x| *x as i64));
+            });
+
+            // Create CowArrays from vectors
+            let inputs_ids_array = Array::from_shape_vec((batch_size, encoding_length), ids_array)?;
+
+            let attention_mask_array =
+                Array::from_shape_vec((batch_size, encoding_length), mask_array)?;
+
+            let token_type_ids_array =
+                Array::from_shape_vec((batch_size, encoding_length), typeids_array)?;
+
+            let mut session_inputs = ort::inputs![
+                "input_ids" => Value::from_array(inputs_ids_array)?,
+                "attention_mask" => Value::from_array(attention_mask_array.view())?,
+            ]?;
+
+            if self.need_token_type_ids {
+                session_inputs.push((
+                    "token_type_ids".into(),
+                    Value::from_array(token_type_ids_array)?.into(),
+                ));
+            }
 
-                Ok(
-                    // Package all the data required for post-processing (e.g. pooling)
-                    // into a SingleBatchOutput struct.
-                    SingleBatchOutput {
-                        session_outputs: self
-                            .session
-                            .run(session_inputs)
-                            .map_err(anyhow::Error::new)?,
-                        attention_mask_array,
-                    },
-                )
-            }))?;
+            Ok(
+                // Package all the data required for post-processing (e.g. pooling)
+                // into a SingleBatchOutput struct.
+                SingleBatchOutput {
+                    session_outputs: self
+                        .session
+                        .run(session_inputs)
+                        .map_err(anyhow::Error::new)?,
+                    attention_mask_array,
+                },
+            )
+        }))?;
 
         Ok(EmbeddingOutput::new(batches))
     }
@@ -308,7 +308,7 @@ impl TextEmbedding {
         &self,
         texts: Vec<S>,
         batch_size: Option<usize>,
-    ) -> anyhow::Result<Vec<Embedding>> {
+    ) -> Result<Vec<Embedding>> {
         let batches = self.transform(texts, batch_size)?;
 
         batches.export_with_transformer(output::transformer_with_precedence(
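
A note on the error-handling pattern this PR leans on: `anyhow::Context` wraps the underlying error in a human-readable message and lets `?` propagate it, so callers get a chained `Err` instead of a process-aborting panic. A minimal sketch of the before/after, not taken from this repo: `fetch` and `load_model` are hypothetical stand-ins for `ApiRepo::get` and the constructors above.

```rust
use anyhow::{Context, Result};
use std::path::PathBuf;

// Hypothetical stand-in for hf_hub's `ApiRepo::get`: any call whose error
// we want to annotate before propagating.
fn fetch(_file_name: &str) -> Result<PathBuf> {
    Err(anyhow::anyhow!("404 Not Found"))
}

fn load_model(model_file_name: &str) -> Result<PathBuf> {
    // Before: .unwrap_or_else(|_| panic!("Failed to retrieve {}", model_file_name))
    // After: attach context and let `?` hand the failure to the caller.
    let path = fetch(model_file_name)
        .with_context(|| format!("Failed to retrieve {}", model_file_name))?;
    Ok(path)
}

fn main() {
    if let Err(e) = load_model("model.onnx") {
        // "{:#}" prints the context chain:
        // "Failed to retrieve model.onnx: 404 Not Found"
        eprintln!("{:#}", e);
    }
}
```

`.with_context(|| ...)` is the lazy variant: unlike the eager `.context(format!(...))` used in the diff, it only builds the message string on the error path, which could be a small follow-up refinement.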
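The `flat_map(|result| result.unwrap())` removals all follow the same shape: collect the per-batch `Result`s first, let `?` short-circuit on the first failure, then flatten the successful batches. A self-contained sketch of the pattern with dummy data; `embed_batch` and `embed_all` are hypothetical, not this crate's API.

```rust
use anyhow::{anyhow, Result};

// Stand-in for one batch of work: all embeddings for a chunk, or a failure.
fn embed_batch(batch: &[&str]) -> Result<Vec<Vec<f32>>> {
    batch
        .iter()
        .map(|text| {
            if text.is_empty() {
                Err(anyhow!("cannot embed an empty string"))
            } else {
                Ok(vec![text.len() as f32]) // dummy embedding
            }
        })
        .collect()
}

fn embed_all(texts: &[&str]) -> Result<Vec<Vec<f32>>> {
    let output = texts
        .chunks(2)
        .map(embed_batch)
        // Before: .flat_map(|result| result.unwrap()) panicked on the first
        // bad batch. Collecting into Result<Vec<_>> instead short-circuits
        // on the first Err and hands it to `?`, so the caller gets an Err.
        .collect::<Result<Vec<_>>>()?
        .into_iter()
        .flatten()
        .collect();
    Ok(output)
}

fn main() -> Result<()> {
    let embeddings = embed_all(&["hello", "world"])?;
    assert_eq!(embeddings.len(), 2);

    // A failing input now surfaces as an Err rather than a panic.
    assert!(embed_all(&["hello", ""]).is_err());
    Ok(())
}
```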
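In `text_embedding/impl.rs` the batches are produced in parallel, so the fail-fast collection goes through rayon instead: rayon implements `FromParallelIterator<Result<T, E>>` for `Result<C, E>`, which short-circuits on the first `Err` just like the sequential version. A small sketch assuming `rayon` and `anyhow` as dependencies; `process_batch` is a hypothetical stand-in for the tokenize-and-run closure in `transform`.

```rust
use anyhow::{anyhow, Result};
use rayon::prelude::*;

// Stand-in for the per-batch closure: succeeds or fails per chunk.
fn process_batch(batch: &[i32]) -> Result<i32> {
    if batch.contains(&0) {
        return Err(anyhow!("batch contained a zero"));
    }
    Ok(batch.iter().sum())
}

fn main() -> Result<()> {
    let data: Vec<i32> = (1..=8).collect();

    // Mirrors `Result::<Vec<_>>::from_par_iter(texts.par_chunks(..).map(..))?`
    // in the diff: the first failing batch aborts the collection with an Err.
    let sums = Result::<Vec<_>>::from_par_iter(data.par_chunks(3).map(process_batch))?;
    assert_eq!(sums, vec![6, 15, 15]);
    Ok(())
}
```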