Skip to content

Commit

Permalink
refactor: Turn most panics into recoverable errors (#128)
Browse files Browse the repository at this point in the history
* for image embeddings, turn some panics into returned errors

* turn more panics into recoverable errors

* Simplify collection of image embeddings

Co-authored-by: Anush  <anushshetty90@gmail.com>

* Pass errors in sparse text embedding and reranking batches onto the user instead of panicking

* fix warnings: only import anyhow::Context if online feature enabled

* cargo fmt

---------

Co-authored-by: Anush <anushshetty90@gmail.com>
  • Loading branch information
llenck and Anush008 authored Nov 26, 2024
1 parent 49a2185 commit 86594fc
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 90 deletions.
13 changes: 8 additions & 5 deletions src/image_embedding/impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ use crate::{
ModelInfo,
};
use anyhow::anyhow;
#[cfg(feature = "online")]
use anyhow::Context;

#[cfg(feature = "online")]
use super::ImageInitOptions;
Expand Down Expand Up @@ -52,13 +54,13 @@ impl ImageEmbedding {

let preprocessor_file = model_repo
.get("preprocessor_config.json")
.unwrap_or_else(|_| panic!("Failed to retrieve preprocessor_config.json"));
.context("Failed to retrieve preprocessor_config.json")?;
let preprocessor = Compose::from_file(preprocessor_file)?;

let model_file_name = ImageEmbedding::get_model_info(&model_name).model_file;
let model_file_reference = model_repo
.get(&model_file_name)
.unwrap_or_else(|_| panic!("Failed to retrieve {} ", model_file_name));
.context(format!("Failed to retrieve {}", model_file_name))?;

let session = Session::builder()?
.with_execution_providers(execution_providers)?
Expand Down Expand Up @@ -111,8 +113,7 @@ impl ImageEmbedding {
let cache = Cache::new(cache_dir);
let api = ApiBuilder::from_cache(cache)
.with_progress(show_download_progress)
.build()
.unwrap();
.build()?;

let repo = api.model(model.to_string());
Ok(repo)
Expand Down Expand Up @@ -189,7 +190,9 @@ impl ImageEmbedding {

Ok(embeddings)
})
.flat_map(|result: Result<Vec<Vec<f32>>, anyhow::Error>| result.unwrap())
.collect::<anyhow::Result<Vec<_>>>()?
.into_iter()
.flatten()
.collect();

Ok(output)
Expand Down
21 changes: 13 additions & 8 deletions src/reranking/impl.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#[cfg(feature = "online")]
use anyhow::Context;
use anyhow::Result;
use ort::{
session::{builder::GraphOptimizationLevel, Session},
Expand Down Expand Up @@ -70,15 +72,16 @@ impl TextRerank {
let model_repo = api.model(model_name.to_string());

let model_file_name = TextRerank::get_model_info(&model_name).model_file;
let model_file_reference = model_repo
.get(&model_file_name)
.unwrap_or_else(|_| panic!("Failed to retrieve model file: {}", model_file_name));
let model_file_reference = model_repo.get(&model_file_name).context(format!(
"Failed to retrieve model file: {}",
model_file_name
))?;
let additional_files = TextRerank::get_model_info(&model_name).additional_files;
for additional_file in additional_files {
let _additional_file_reference =
model_repo.get(&additional_file).unwrap_or_else(|_| {
panic!("Failed to retrieve additional file: {}", additional_file)
});
let _additional_file_reference = model_repo.get(&additional_file).context(format!(
"Failed to retrieve additional file: {}",
additional_file
))?;
}

let session = Session::builder()?
Expand Down Expand Up @@ -196,7 +199,9 @@ impl TextRerank {

Ok(scores)
})
.flat_map(|result: Result<Vec<f32>, anyhow::Error>| result.unwrap())
.collect::<Result<Vec<_>>>()?
.into_iter()
.flatten()
.collect();

// Return top_n_result of type Vec<RerankResult> ordered by score in descending order, don't use binary heap
Expand Down
11 changes: 7 additions & 4 deletions src/sparse_text_embedding/impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ use crate::{
models::sparse::{models_list, SparseModel},
ModelInfo, SparseEmbedding,
};
#[cfg(feature = "online")]
use anyhow::Context;
use anyhow::Result;
#[cfg(feature = "online")]
use hf_hub::{
Expand Down Expand Up @@ -55,7 +57,7 @@ impl SparseTextEmbedding {
let model_file_name = SparseTextEmbedding::get_model_info(&model_name).model_file;
let model_file_reference = model_repo
.get(&model_file_name)
.unwrap_or_else(|_| panic!("Failed to retrieve {} ", model_file_name));
.context(format!("Failed to retrieve {} ", model_file_name))?;

let session = Session::builder()?
.with_execution_providers(execution_providers)?
Expand Down Expand Up @@ -91,8 +93,7 @@ impl SparseTextEmbedding {
let cache = Cache::new(cache_dir);
let api = ApiBuilder::from_cache(cache)
.with_progress(show_download_progress)
.build()
.unwrap();
.build()?;

let repo = api.model(model.to_string());
Ok(repo)
Expand Down Expand Up @@ -189,7 +190,9 @@ impl SparseTextEmbedding {

Ok(embeddings)
})
.flat_map(|result: Result<Vec<SparseEmbedding>, anyhow::Error>| result.unwrap())
.collect::<Result<Vec<_>>>()?
.into_iter()
.flatten()
.collect();

Ok(output)
Expand Down
146 changes: 73 additions & 73 deletions src/text_embedding/impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ use crate::{
Embedding, EmbeddingModel, EmbeddingOutput, ModelInfo, QuantizationMode, SingleBatchOutput,
};
#[cfg(feature = "online")]
use anyhow::Context;
use anyhow::Result;
#[cfg(feature = "online")]
use hf_hub::{
api::sync::{ApiBuilder, ApiRepo},
Cache,
Expand Down Expand Up @@ -40,7 +43,7 @@ impl TextEmbedding {
///
/// Uses the total number of CPUs available as the number of intra-threads
#[cfg(feature = "online")]
pub fn try_new(options: InitOptions) -> anyhow::Result<Self> {
pub fn try_new(options: InitOptions) -> Result<Self> {
let InitOptions {
model_name,
execution_providers,
Expand All @@ -61,7 +64,7 @@ impl TextEmbedding {
let model_file_name = &model_info.model_file;
let model_file_reference = model_repo
.get(model_file_name)
.unwrap_or_else(|_| panic!("Failed to retrieve {} ", model_file_name));
.context(format!("Failed to retrieve {}", model_file_name))?;

// TODO: If more models need .onnx_data, implement a better way to handle this
// Probably by adding `additional_files` field in the `ModelInfo` struct
Expand Down Expand Up @@ -95,7 +98,7 @@ impl TextEmbedding {
pub fn try_new_from_user_defined(
model: UserDefinedEmbeddingModel,
options: InitOptionsUserDefined,
) -> anyhow::Result<Self> {
) -> Result<Self> {
let InitOptionsUserDefined {
execution_providers,
max_length,
Expand Down Expand Up @@ -147,8 +150,7 @@ impl TextEmbedding {
let cache = Cache::new(cache_dir);
let api = ApiBuilder::from_cache(cache)
.with_progress(show_download_progress)
.build()
.unwrap();
.build()?;

let repo = api.model(model.to_string());
Ok(repo)
Expand All @@ -160,7 +162,7 @@ impl TextEmbedding {
}

/// Get ModelInfo from EmbeddingModel
pub fn get_model_info(model: &EmbeddingModel) -> anyhow::Result<&ModelInfo<EmbeddingModel>> {
pub fn get_model_info(model: &EmbeddingModel) -> Result<&ModelInfo<EmbeddingModel>> {
get_model_info(model).ok_or_else(|| {
anyhow::Error::msg(format!(
"Model {model:?} not found. Please check if the model is supported \
Expand Down Expand Up @@ -195,7 +197,7 @@ impl TextEmbedding {
&'e self,
texts: Vec<S>,
batch_size: Option<usize>,
) -> anyhow::Result<EmbeddingOutput<'r, 's>>
) -> Result<EmbeddingOutput<'r, 's>>
where
'e: 'r,
'e: 's,
Expand Down Expand Up @@ -223,72 +225,70 @@ impl TextEmbedding {
_ => Ok(batch_size.unwrap_or(DEFAULT_BATCH_SIZE)),
}?;

let batches =
anyhow::Result::<Vec<_>>::from_par_iter(texts.par_chunks(batch_size).map(|batch| {
// Encode the texts in the batch
let inputs = batch.iter().map(|text| text.as_ref()).collect();
let encodings = self.tokenizer.encode_batch(inputs, true).map_err(|e| {
anyhow::Error::msg(e.to_string()).context("Failed to encode the batch.")
})?;

// Extract the encoding length and batch size
let encoding_length = encodings[0].len();
let batch_size = batch.len();

let max_size = encoding_length * batch_size;

// Preallocate arrays with the maximum size
let mut ids_array = Vec::with_capacity(max_size);
let mut mask_array = Vec::with_capacity(max_size);
let mut typeids_array = Vec::with_capacity(max_size);

// Not using par_iter because the closure needs to be FnMut
encodings.iter().for_each(|encoding| {
let ids = encoding.get_ids();
let mask = encoding.get_attention_mask();
let typeids = encoding.get_type_ids();

// Extend the preallocated arrays with the current encoding
// Requires the closure to be FnMut
ids_array.extend(ids.iter().map(|x| *x as i64));
mask_array.extend(mask.iter().map(|x| *x as i64));
typeids_array.extend(typeids.iter().map(|x| *x as i64));
});

// Create CowArrays from vectors
let inputs_ids_array =
Array::from_shape_vec((batch_size, encoding_length), ids_array)?;

let attention_mask_array =
Array::from_shape_vec((batch_size, encoding_length), mask_array)?;

let token_type_ids_array =
Array::from_shape_vec((batch_size, encoding_length), typeids_array)?;

let mut session_inputs = ort::inputs![
"input_ids" => Value::from_array(inputs_ids_array)?,
"attention_mask" => Value::from_array(attention_mask_array.view())?,
]?;

if self.need_token_type_ids {
session_inputs.push((
"token_type_ids".into(),
Value::from_array(token_type_ids_array)?.into(),
));
}
let batches = Result::<Vec<_>>::from_par_iter(texts.par_chunks(batch_size).map(|batch| {
// Encode the texts in the batch
let inputs = batch.iter().map(|text| text.as_ref()).collect();
let encodings = self.tokenizer.encode_batch(inputs, true).map_err(|e| {
anyhow::Error::msg(e.to_string()).context("Failed to encode the batch.")
})?;

// Extract the encoding length and batch size
let encoding_length = encodings[0].len();
let batch_size = batch.len();

let max_size = encoding_length * batch_size;

// Preallocate arrays with the maximum size
let mut ids_array = Vec::with_capacity(max_size);
let mut mask_array = Vec::with_capacity(max_size);
let mut typeids_array = Vec::with_capacity(max_size);

// Not using par_iter because the closure needs to be FnMut
encodings.iter().for_each(|encoding| {
let ids = encoding.get_ids();
let mask = encoding.get_attention_mask();
let typeids = encoding.get_type_ids();

// Extend the preallocated arrays with the current encoding
// Requires the closure to be FnMut
ids_array.extend(ids.iter().map(|x| *x as i64));
mask_array.extend(mask.iter().map(|x| *x as i64));
typeids_array.extend(typeids.iter().map(|x| *x as i64));
});

// Create CowArrays from vectors
let inputs_ids_array = Array::from_shape_vec((batch_size, encoding_length), ids_array)?;

let attention_mask_array =
Array::from_shape_vec((batch_size, encoding_length), mask_array)?;

let token_type_ids_array =
Array::from_shape_vec((batch_size, encoding_length), typeids_array)?;

let mut session_inputs = ort::inputs![
"input_ids" => Value::from_array(inputs_ids_array)?,
"attention_mask" => Value::from_array(attention_mask_array.view())?,
]?;

if self.need_token_type_ids {
session_inputs.push((
"token_type_ids".into(),
Value::from_array(token_type_ids_array)?.into(),
));
}

Ok(
// Package all the data required for post-processing (e.g. pooling)
// into a SingleBatchOutput struct.
SingleBatchOutput {
session_outputs: self
.session
.run(session_inputs)
.map_err(anyhow::Error::new)?,
attention_mask_array,
},
)
}))?;
Ok(
// Package all the data required for post-processing (e.g. pooling)
// into a SingleBatchOutput struct.
SingleBatchOutput {
session_outputs: self
.session
.run(session_inputs)
.map_err(anyhow::Error::new)?,
attention_mask_array,
},
)
}))?;

Ok(EmbeddingOutput::new(batches))
}
Expand All @@ -308,7 +308,7 @@ impl TextEmbedding {
&self,
texts: Vec<S>,
batch_size: Option<usize>,
) -> anyhow::Result<Vec<Embedding>> {
) -> Result<Vec<Embedding>> {
let batches = self.transform(texts, batch_size)?;

batches.export_with_transformer(output::transformer_with_precedence(
Expand Down

0 comments on commit 86594fc

Please sign in to comment.