refactor: Turn most panics into recoverable errors #128

Merged: 6 commits, Nov 26, 2024
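The entire PR applies one pattern: fallible lookups that previously aborted the process via `panic!` or `.unwrap()` now attach a message with `anyhow::Context` and propagate the failure with `?`, so callers can recover. Below is a minimal sketch of the before/after; the `fetch_file` helper is hypothetical, standing in for calls like `model_repo.get(...)` in the diffs.

```rust
use anyhow::{Context, Result};
use std::io::{Error, ErrorKind};

// Hypothetical stand-in for a fallible lookup such as `model_repo.get(...)`.
fn fetch_file(name: &str) -> std::result::Result<String, Error> {
    Err(Error::new(ErrorKind::NotFound, name.to_string()))
}

// Before: any failure tears down the whole process.
fn load_panicking(name: &str) -> String {
    fetch_file(name).unwrap_or_else(|_| panic!("Failed to retrieve {}", name))
}

// After: the failure becomes an `anyhow::Error` the caller can inspect or recover from.
fn load_recoverable(name: &str) -> Result<String> {
    fetch_file(name).context(format!("Failed to retrieve {}", name))
}

fn main() {
    match load_recoverable("preprocessor_config.json") {
        Ok(contents) => println!("got {} bytes", contents.len()),
        Err(e) => eprintln!("recoverable failure: {:#}", e), // `{:#}` prints the context chain
    }
}
```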
src/image_embedding/impl.rs (13 changes: 8 additions & 5 deletions)
@@ -17,6 +17,8 @@ use crate::{
     ModelInfo,
 };
 use anyhow::anyhow;
+#[cfg(feature = "online")]
+use anyhow::Context;
 
 #[cfg(feature = "online")]
 use super::ImageInitOptions;
@@ -52,13 +54,13 @@ impl ImageEmbedding {
 
         let preprocessor_file = model_repo
             .get("preprocessor_config.json")
-            .unwrap_or_else(|_| panic!("Failed to retrieve preprocessor_config.json"));
+            .context("Failed to retrieve preprocessor_config.json")?;
         let preprocessor = Compose::from_file(preprocessor_file)?;
 
         let model_file_name = ImageEmbedding::get_model_info(&model_name).model_file;
         let model_file_reference = model_repo
             .get(&model_file_name)
-            .unwrap_or_else(|_| panic!("Failed to retrieve {} ", model_file_name));
+            .context(format!("Failed to retrieve {}", model_file_name))?;
 
         let session = Session::builder()?
             .with_execution_providers(execution_providers)?
@@ -111,8 +113,7 @@ impl ImageEmbedding {
         let cache = Cache::new(cache_dir);
         let api = ApiBuilder::from_cache(cache)
             .with_progress(show_download_progress)
-            .build()
-            .unwrap();
+            .build()?;
 
         let repo = api.model(model.to_string());
         Ok(repo)
@@ -189,7 +190,9 @@ impl ImageEmbedding {
 
                 Ok(embeddings)
             })
-            .flat_map(|result: Result<Vec<Vec<f32>>, anyhow::Error>| result.unwrap())
+            .collect::<anyhow::Result<Vec<_>>>()?
+            .into_iter()
+            .flatten()
             .collect();
 
         Ok(output)
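The batch loops change in the same spirit. Rayon implements `FromParallelIterator` for `Result`, so the per-batch `Result`s can be collected into one `anyhow::Result<Vec<_>>`, and a failing batch surfaces as an error instead of being `unwrap`ped inside `flat_map`. A sketch of the idea with plain numbers in place of embeddings (the batch data is made up):

```rust
use anyhow::{anyhow, Result};
use rayon::prelude::*;

fn main() -> Result<()> {
    let batches: Vec<Vec<i32>> = vec![vec![1, 2], vec![3, 4], vec![5]];

    let output: Vec<i32> = batches
        .par_iter()
        .map(|batch| -> Result<Vec<i32>> {
            if batch.is_empty() {
                // Previously this would have been an `unwrap` panic inside `flat_map`.
                return Err(anyhow!("empty batch"));
            }
            Ok(batch.iter().map(|x| x * 10).collect())
        })
        .collect::<Result<Vec<_>>>()? // a failing batch aborts the collection here
        .into_iter()
        .flatten()
        .collect();

    println!("{:?}", output); // [10, 20, 30, 40, 50]
    Ok(())
}
```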
src/reranking/impl.rs (21 changes: 13 additions & 8 deletions)
@@ -1,3 +1,5 @@
+#[cfg(feature = "online")]
+use anyhow::Context;
 use anyhow::Result;
 use ort::{
     session::{builder::GraphOptimizationLevel, Session},
@@ -70,15 +72,16 @@ impl TextRerank {
         let model_repo = api.model(model_name.to_string());
 
         let model_file_name = TextRerank::get_model_info(&model_name).model_file;
-        let model_file_reference = model_repo
-            .get(&model_file_name)
-            .unwrap_or_else(|_| panic!("Failed to retrieve model file: {}", model_file_name));
+        let model_file_reference = model_repo.get(&model_file_name).context(format!(
+            "Failed to retrieve model file: {}",
+            model_file_name
+        ))?;
         let additional_files = TextRerank::get_model_info(&model_name).additional_files;
         for additional_file in additional_files {
-            let _additional_file_reference =
-                model_repo.get(&additional_file).unwrap_or_else(|_| {
-                    panic!("Failed to retrieve additional file: {}", additional_file)
-                });
+            let _additional_file_reference = model_repo.get(&additional_file).context(format!(
+                "Failed to retrieve additional file: {}",
+                additional_file
+            ))?;
         }
 
         let session = Session::builder()?
@@ -196,7 +199,9 @@ impl TextRerank {
 
                 Ok(scores)
             })
-            .flat_map(|result: Result<Vec<f32>, anyhow::Error>| result.unwrap())
+            .collect::<Result<Vec<_>>>()?
+            .into_iter()
+            .flatten()
             .collect();
 
         // Return top_n_result of type Vec<RerankResult> ordered by score in descending order, don't use binary heap
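One observation about the hunks above (an editorial note, not part of the PR): `.context(format!(...))` builds its message eagerly, even when the call succeeds. anyhow also offers `with_context`, which takes a closure and defers the formatting to the error path. A sketch using `std::fs::read` as a stand-in for the repo lookup:

```rust
use anyhow::{Context, Result};
use std::fs;

// Hypothetical loader; `model_file_name` mirrors the variable in the diff.
fn load_model_file(model_file_name: &str) -> Result<Vec<u8>> {
    fs::read(model_file_name)
        .with_context(|| format!("Failed to retrieve model file: {}", model_file_name))
}
```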
src/sparse_text_embedding/impl.rs (11 changes: 7 additions & 4 deletions)
@@ -4,6 +4,8 @@ use crate::{
     models::sparse::{models_list, SparseModel},
     ModelInfo, SparseEmbedding,
 };
+#[cfg(feature = "online")]
+use anyhow::Context;
 use anyhow::Result;
 #[cfg(feature = "online")]
 use hf_hub::{
@@ -55,7 +57,7 @@ impl SparseTextEmbedding {
         let model_file_name = SparseTextEmbedding::get_model_info(&model_name).model_file;
         let model_file_reference = model_repo
             .get(&model_file_name)
-            .unwrap_or_else(|_| panic!("Failed to retrieve {} ", model_file_name));
+            .context(format!("Failed to retrieve {} ", model_file_name))?;
 
         let session = Session::builder()?
             .with_execution_providers(execution_providers)?
@@ -91,8 +93,7 @@ impl SparseTextEmbedding {
         let cache = Cache::new(cache_dir);
         let api = ApiBuilder::from_cache(cache)
             .with_progress(show_download_progress)
-            .build()
-            .unwrap();
+            .build()?;
 
         let repo = api.model(model.to_string());
         Ok(repo)
@@ -189,7 +190,9 @@ impl SparseTextEmbedding {
 
                 Ok(embeddings)
             })
-            .flat_map(|result: Result<Vec<SparseEmbedding>, anyhow::Error>| result.unwrap())
+            .collect::<Result<Vec<_>>>()?
+            .into_iter()
+            .flatten()
             .collect();
 
         Ok(output)
src/text_embedding/impl.rs (146 changes: 73 additions & 73 deletions)
@@ -9,6 +9,9 @@ use crate::{
     Embedding, EmbeddingModel, EmbeddingOutput, ModelInfo, QuantizationMode, SingleBatchOutput,
 };
 #[cfg(feature = "online")]
+use anyhow::Context;
+use anyhow::Result;
+#[cfg(feature = "online")]
 use hf_hub::{
     api::sync::{ApiBuilder, ApiRepo},
     Cache,
@@ -40,7 +43,7 @@ impl TextEmbedding {
     ///
     /// Uses the total number of CPUs available as the number of intra-threads
     #[cfg(feature = "online")]
-    pub fn try_new(options: InitOptions) -> anyhow::Result<Self> {
+    pub fn try_new(options: InitOptions) -> Result<Self> {
         let InitOptions {
             model_name,
             execution_providers,
@@ -61,7 +64,7 @@
         let model_file_name = &model_info.model_file;
         let model_file_reference = model_repo
             .get(model_file_name)
-            .unwrap_or_else(|_| panic!("Failed to retrieve {} ", model_file_name));
+            .context(format!("Failed to retrieve {}", model_file_name))?;
 
         // TODO: If more models need .onnx_data, implement a better way to handle this
         // Probably by adding `additional_files` field in the `ModelInfo` struct
@@ -95,7 +98,7 @@ impl TextEmbedding {
     pub fn try_new_from_user_defined(
         model: UserDefinedEmbeddingModel,
         options: InitOptionsUserDefined,
-    ) -> anyhow::Result<Self> {
+    ) -> Result<Self> {
         let InitOptionsUserDefined {
             execution_providers,
             max_length,
@@ -147,8 +150,7 @@ impl TextEmbedding {
         let cache = Cache::new(cache_dir);
         let api = ApiBuilder::from_cache(cache)
             .with_progress(show_download_progress)
-            .build()
-            .unwrap();
+            .build()?;
 
         let repo = api.model(model.to_string());
         Ok(repo)
@@ -160,7 +162,7 @@
     }
 
     /// Get ModelInfo from EmbeddingModel
-    pub fn get_model_info(model: &EmbeddingModel) -> anyhow::Result<&ModelInfo<EmbeddingModel>> {
+    pub fn get_model_info(model: &EmbeddingModel) -> Result<&ModelInfo<EmbeddingModel>> {
         get_model_info(model).ok_or_else(|| {
             anyhow::Error::msg(format!(
                 "Model {model:?} not found. Please check if the model is supported \
@@ -195,7 +197,7 @@ impl TextEmbedding {
         &'e self,
         texts: Vec<S>,
         batch_size: Option<usize>,
-    ) -> anyhow::Result<EmbeddingOutput<'r, 's>>
+    ) -> Result<EmbeddingOutput<'r, 's>>
     where
         'e: 'r,
         'e: 's,
@@ -223,72 +225,70 @@ impl TextEmbedding {
             _ => Ok(batch_size.unwrap_or(DEFAULT_BATCH_SIZE)),
         }?;
 
-        let batches =
-            anyhow::Result::<Vec<_>>::from_par_iter(texts.par_chunks(batch_size).map(|batch| {
-                // Encode the texts in the batch
-                let inputs = batch.iter().map(|text| text.as_ref()).collect();
-                let encodings = self.tokenizer.encode_batch(inputs, true).map_err(|e| {
-                    anyhow::Error::msg(e.to_string()).context("Failed to encode the batch.")
-                })?;
-
-                // Extract the encoding length and batch size
-                let encoding_length = encodings[0].len();
-                let batch_size = batch.len();
-
-                let max_size = encoding_length * batch_size;
-
-                // Preallocate arrays with the maximum size
-                let mut ids_array = Vec::with_capacity(max_size);
-                let mut mask_array = Vec::with_capacity(max_size);
-                let mut typeids_array = Vec::with_capacity(max_size);
-
-                // Not using par_iter because the closure needs to be FnMut
-                encodings.iter().for_each(|encoding| {
-                    let ids = encoding.get_ids();
-                    let mask = encoding.get_attention_mask();
-                    let typeids = encoding.get_type_ids();
-
-                    // Extend the preallocated arrays with the current encoding
-                    // Requires the closure to be FnMut
-                    ids_array.extend(ids.iter().map(|x| *x as i64));
-                    mask_array.extend(mask.iter().map(|x| *x as i64));
-                    typeids_array.extend(typeids.iter().map(|x| *x as i64));
-                });
-
-                // Create CowArrays from vectors
-                let inputs_ids_array =
-                    Array::from_shape_vec((batch_size, encoding_length), ids_array)?;
-
-                let attention_mask_array =
-                    Array::from_shape_vec((batch_size, encoding_length), mask_array)?;
-
-                let token_type_ids_array =
-                    Array::from_shape_vec((batch_size, encoding_length), typeids_array)?;
-
-                let mut session_inputs = ort::inputs![
-                    "input_ids" => Value::from_array(inputs_ids_array)?,
-                    "attention_mask" => Value::from_array(attention_mask_array.view())?,
-                ]?;
-
-                if self.need_token_type_ids {
-                    session_inputs.push((
-                        "token_type_ids".into(),
-                        Value::from_array(token_type_ids_array)?.into(),
-                    ));
-                }
+        let batches = Result::<Vec<_>>::from_par_iter(texts.par_chunks(batch_size).map(|batch| {
+            // Encode the texts in the batch
+            let inputs = batch.iter().map(|text| text.as_ref()).collect();
+            let encodings = self.tokenizer.encode_batch(inputs, true).map_err(|e| {
+                anyhow::Error::msg(e.to_string()).context("Failed to encode the batch.")
+            })?;
+
+            // Extract the encoding length and batch size
+            let encoding_length = encodings[0].len();
+            let batch_size = batch.len();
+
+            let max_size = encoding_length * batch_size;
+
+            // Preallocate arrays with the maximum size
+            let mut ids_array = Vec::with_capacity(max_size);
+            let mut mask_array = Vec::with_capacity(max_size);
+            let mut typeids_array = Vec::with_capacity(max_size);
+
+            // Not using par_iter because the closure needs to be FnMut
+            encodings.iter().for_each(|encoding| {
+                let ids = encoding.get_ids();
+                let mask = encoding.get_attention_mask();
+                let typeids = encoding.get_type_ids();
+
+                // Extend the preallocated arrays with the current encoding
+                // Requires the closure to be FnMut
+                ids_array.extend(ids.iter().map(|x| *x as i64));
+                mask_array.extend(mask.iter().map(|x| *x as i64));
+                typeids_array.extend(typeids.iter().map(|x| *x as i64));
+            });
+
+            // Create CowArrays from vectors
+            let inputs_ids_array = Array::from_shape_vec((batch_size, encoding_length), ids_array)?;
+
+            let attention_mask_array =
+                Array::from_shape_vec((batch_size, encoding_length), mask_array)?;
+
+            let token_type_ids_array =
+                Array::from_shape_vec((batch_size, encoding_length), typeids_array)?;
+
+            let mut session_inputs = ort::inputs![
+                "input_ids" => Value::from_array(inputs_ids_array)?,
+                "attention_mask" => Value::from_array(attention_mask_array.view())?,
+            ]?;
+
+            if self.need_token_type_ids {
+                session_inputs.push((
+                    "token_type_ids".into(),
+                    Value::from_array(token_type_ids_array)?.into(),
+                ));
+            }
-
-                Ok(
-                    // Package all the data required for post-processing (e.g. pooling)
-                    // into a SingleBatchOutput struct.
-                    SingleBatchOutput {
-                        session_outputs: self
-                            .session
-                            .run(session_inputs)
-                            .map_err(anyhow::Error::new)?,
-                        attention_mask_array,
-                    },
-                )
-            }))?;
+
+            Ok(
+                // Package all the data required for post-processing (e.g. pooling)
+                // into a SingleBatchOutput struct.
+                SingleBatchOutput {
+                    session_outputs: self
+                        .session
+                        .run(session_inputs)
+                        .map_err(anyhow::Error::new)?,
+                    attention_mask_array,
+                },
+            )
+        }))?;
 
         Ok(EmbeddingOutput::new(batches))
     }
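The `transform` hunk above is mostly re-indentation: with `Result` now imported at the top of the file, `anyhow::Result::<Vec<_>>::from_par_iter(...)` shortens to `Result::<Vec<_>>::from_par_iter(...)`, fits on one line, and shifts the closure body left. `from_par_iter` is the explicit spelling of the turbofish `collect` used in the other files; a small sketch under the same assumptions as the earlier one:

```rust
use anyhow::{anyhow, Result};
use rayon::prelude::*;

fn main() -> Result<()> {
    let chunks = vec!["a", "bb", "ccc"];

    // Equivalent to `.collect::<Result<Vec<_>>>()`, written as an explicit
    // `FromParallelIterator` call, as in `TextEmbedding::transform`.
    let lengths = Result::<Vec<_>>::from_par_iter(chunks.par_iter().map(|chunk| {
        if chunk.is_empty() {
            return Err(anyhow!("empty chunk"));
        }
        Ok(chunk.len())
    }))?;

    println!("{:?}", lengths); // [1, 2, 3]
    Ok(())
}
```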
@@ -308,7 +308,7 @@ impl TextEmbedding {
         &self,
         texts: Vec<S>,
         batch_size: Option<usize>,
-    ) -> anyhow::Result<Vec<Embedding>> {
+    ) -> Result<Vec<Embedding>> {
         let batches = self.transform(texts, batch_size)?;
 
         batches.export_with_transformer(output::transformer_with_precedence(