From 0024610ac1afc89b6d55efa402c53f0fdf703321 Mon Sep 17 00:00:00 2001 From: HabeebShopeju Date: Wed, 30 Oct 2024 02:36:25 +0000 Subject: [PATCH 01/15] Rough setup for ClipVit Text --- ahnlich/Cargo.lock | 39 ++- ahnlich/ai/Cargo.toml | 1 + ahnlich/ai/src/cli/server.rs | 20 +- ahnlich/ai/src/engine/ai/models.rs | 17 +- .../ai/src/engine/ai/providers/fastembed.rs | 41 +-- ahnlich/ai/src/engine/ai/providers/mod.rs | 1 + ahnlich/ai/src/engine/ai/providers/ort.rs | 276 ++++++++++++++---- .../ai/src/engine/ai/providers/ort_helper.rs | 119 ++++++++ ahnlich/ai/src/error.rs | 16 +- ahnlich/ai/src/manager/mod.rs | 28 +- ahnlich/dsl/src/ai.rs | 2 +- ahnlich/types/src/ai/mod.rs | 3 +- .../ahnlich_client_py/internals/ai_query.py | 10 +- .../internals/ai_response.py | 4 +- sdk/ahnlich-client-py/demo_embed.py | 8 +- type_specs/query/ai_query.json | 2 +- type_specs/response/ai_response.json | 2 +- 17 files changed, 476 insertions(+), 113 deletions(-) create mode 100644 ahnlich/ai/src/engine/ai/providers/ort_helper.rs diff --git a/ahnlich/Cargo.lock b/ahnlich/Cargo.lock index 79ec5947..8a7ca0ed 100644 --- a/ahnlich/Cargo.lock +++ b/ahnlich/Cargo.lock @@ -123,6 +123,7 @@ dependencies = [ "termcolor", "thiserror", "tiktoken-rs", + "tokenizers 0.20.1", "tokio", "tokio-util", "tracer", @@ -1092,6 +1093,9 @@ name = "esaxx-rs" version = "0.1.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d817e038c30374a4bcb22f94d0a8a0e216958d4c3dcde369b1439fec4bdda6e6" +dependencies = [ + "cc", +] [[package]] name = "event-listener" @@ -1162,7 +1166,7 @@ dependencies = [ "ort", "rayon", "serde_json", - "tokenizers", + "tokenizers 0.19.1", ] [[package]] @@ -3510,6 +3514,39 @@ dependencies = [ "unicode_categories", ] +[[package]] +name = "tokenizers" +version = "0.20.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b172ffa9a2e5c31bbddc940cd5725d933ced983a9333bbebc4c7eda3bbce1557" +dependencies = [ + "aho-corasick", + "derive_builder", + "esaxx-rs", + "getrandom", + "hf-hub", + "indicatif", + "itertools 0.12.1", + "lazy_static", + "log", + "macro_rules_attribute", + "monostate", + "onig", + "paste", + "rand", + "rayon", + "rayon-cond", + "regex", + "regex-syntax 0.8.5", + "serde", + "serde_json", + "spm_precompiled", + "thiserror", + "unicode-normalization-alignments", + "unicode-segmentation", + "unicode_categories", +] + [[package]] name = "tokio" version = "1.41.0" diff --git a/ahnlich/ai/Cargo.toml b/ahnlich/ai/Cargo.toml index e07b09d9..1a38bf7c 100644 --- a/ahnlich/ai/Cargo.toml +++ b/ahnlich/ai/Cargo.toml @@ -47,6 +47,7 @@ tracing-opentelemetry.workspace = true futures.workspace = true tiktoken-rs = "0.5.9" itertools.workspace = true +tokenizers = { version = "0.20.1", features = ["hf-hub"] } [dev-dependencies] db = { path = "../db", version = "*" } pretty_assertions.workspace = true diff --git a/ahnlich/ai/src/cli/server.rs b/ahnlich/ai/src/cli/server.rs index 5dda4db0..50802f7f 100644 --- a/ahnlich/ai/src/cli/server.rs +++ b/ahnlich/ai/src/cli/server.rs @@ -22,8 +22,10 @@ pub enum SupportedModels { BGELargeEnV15, #[clap(name = "resnet-50")] Resnet50, - #[clap(name = "clip-vit-b32")] - ClipVitB32, + #[clap(name = "clip-vit-b32-image")] + ClipVitB32Image, + #[clap(name = "clip-vit-b32-text")] + ClipVitB32Text, } #[derive(Parser)] @@ -132,9 +134,10 @@ impl Default for AIProxyConfig { supported_models: vec![ SupportedModels::AllMiniLML6V2, SupportedModels::AllMiniLML12V2, - SupportedModels::Resnet50, - SupportedModels::ClipVitB32, 
SupportedModels::BGEBaseEnV15, + SupportedModels::ClipVitB32Text, + SupportedModels::Resnet50, + SupportedModels::ClipVitB32Image, ], model_cache_location: home_dir() .map(|mut path| { @@ -192,7 +195,8 @@ impl fmt::Display for SupportedModels { SupportedModels::BGEBaseEnV15 => write!(f, "BGEBase-En-v1.5"), SupportedModels::BGELargeEnV15 => write!(f, "BGELarge-En-v1.5"), SupportedModels::Resnet50 => write!(f, "Resnet-50"), - SupportedModels::ClipVitB32 => write!(f, "ClipVit-B32"), + SupportedModels::ClipVitB32Image => write!(f, "ClipVit-B32-Image"), + SupportedModels::ClipVitB32Text => write!(f, "ClipVit-B32-Text"), } } } @@ -205,7 +209,8 @@ impl From<&AIModel> for SupportedModels { AIModel::BGEBaseEnV15 => SupportedModels::BGEBaseEnV15, AIModel::BGELargeEnV15 => SupportedModels::BGELargeEnV15, AIModel::Resnet50 => SupportedModels::Resnet50, - AIModel::ClipVitB32 => SupportedModels::ClipVitB32, + AIModel::ClipVitB32Image => SupportedModels::ClipVitB32Image, + AIModel::ClipVitB32Text => SupportedModels::ClipVitB32Text, } } } @@ -218,7 +223,8 @@ impl From<&SupportedModels> for AIModel { SupportedModels::BGEBaseEnV15 => AIModel::BGEBaseEnV15, SupportedModels::BGELargeEnV15 => AIModel::BGELargeEnV15, SupportedModels::Resnet50 => AIModel::Resnet50, - SupportedModels::ClipVitB32 => AIModel::ClipVitB32, + SupportedModels::ClipVitB32Image => AIModel::ClipVitB32Image, + SupportedModels::ClipVitB32Text => AIModel::ClipVitB32Text, } } } diff --git a/ahnlich/ai/src/engine/ai/models.rs b/ahnlich/ai/src/engine/ai/models.rs index 6a0afc66..f430ef35 100644 --- a/ahnlich/ai/src/engine/ai/models.rs +++ b/ahnlich/ai/src/engine/ai/models.rs @@ -92,17 +92,30 @@ impl From<&AIModel> for Model { description: String::from("Residual Networks model, with 50 layers."), embedding_size: nonzero!(2048usize), }, - AIModel::ClipVitB32 => Self { + AIModel::ClipVitB32Image => Self { model_type: ModelType::Image { expected_image_dimensions: (nonzero!(224usize), nonzero!(224usize)), }, provider: ModelProviders::ORT(ORTProvider::new()), - supported_model: SupportedModels::ClipVitB32, + supported_model: SupportedModels::ClipVitB32Image, description: String::from( "Contrastive Language-Image Pre-Training Vision transformer model, base scale.", ), embedding_size: nonzero!(512usize), }, + AIModel::ClipVitB32Text => Self { + model_type: ModelType::Text { + // Token size source: https://github.com/UKPLab/sentence-transformers/issues/1269 + max_input_tokens: nonzero!(77usize), + }, + provider: ModelProviders::ORT(ORTProvider::new()), + supported_model: SupportedModels::ClipVitB32Text, + description: String::from( + "Contrastive Language-Image Pre-Training Text transformer model, base scale. 
\ + Ideal for embedding very short text and using in combination with ClipVitB32Image", + ), + embedding_size: nonzero!(512usize), + }, } } } diff --git a/ahnlich/ai/src/engine/ai/providers/fastembed.rs b/ahnlich/ai/src/engine/ai/providers/fastembed.rs index f97f43b6..98e849c8 100644 --- a/ahnlich/ai/src/engine/ai/providers/fastembed.rs +++ b/ahnlich/ai/src/engine/ai/providers/fastembed.rs @@ -40,17 +40,20 @@ pub struct FastEmbedPreprocessor { } // TODO (HAKSOAT): Implement other preprocessors -impl TextPreprocessorTrait for FastEmbedPreprocessor { +impl TextPreprocessorTrait for FastEmbedProvider { fn encode_str(&self, text: &str) -> Result, AIProxyError> { - let tokens = self.tokenizer.0.encode_with_special_tokens(text); + let preprocessor = self.preprocessor.as_ref() + .ok_or(AIProxyError::AIModelNotInitialized)?; + let tokens = preprocessor + .tokenizer.0.encode_with_special_tokens(text); Ok(tokens) } fn decode_tokens(&self, tokens: Vec) -> Result { - let text = self - .tokenizer - .0 - .decode(tokens) + let preprocessor = self.preprocessor.as_ref() + .ok_or(AIProxyError::AIModelNotInitialized)?; + let text = preprocessor + .tokenizer.0.decode(tokens) .map_err(|_| AIProxyError::ModelTokenizationError)?; Ok(text) } @@ -87,11 +90,11 @@ impl TryFrom<&SupportedModels> for FastEmbedModelType { SupportedModels::AllMiniLML12V2 => EmbeddingModel::AllMiniLML12V2, SupportedModels::BGEBaseEnV15 => EmbeddingModel::BGEBaseENV15, SupportedModels::BGELargeEnV15 => EmbeddingModel::BGELargeENV15, - _ => return Err(AIProxyError::AIModelNotSupported), + _ => return Err(AIProxyError::AIModelNotSupported { model_name: model.to_string() }), }; FastEmbedModelType::Text(model_type) } - _ => return Err(AIProxyError::AIModelNotSupported), + _ => return Err(AIProxyError::AIModelNotSupported { model_name: model.to_string() }), }; Ok(model_type) } @@ -116,22 +119,6 @@ impl FastEmbedProvider { self.preprocessor = Some(FastEmbedPreprocessor { tokenizer }); Ok(()) } - - pub fn encode_str(&self, text: &str) -> Result, AIProxyError> { - if let Some(preprocessor) = &self.preprocessor { - preprocessor.encode_str(text) - } else { - Err(AIProxyError::AIModelNotInitialized) - } - } - - pub fn decode_tokens(&self, tokens: Vec) -> Result { - if let Some(preprocessor) = &self.preprocessor { - preprocessor.decode_tokens(tokens) - } else { - Err(AIProxyError::AIModelNotInitialized) - } - } } impl ProviderTrait for FastEmbedProvider { @@ -211,7 +198,7 @@ impl ProviderTrait for FastEmbedProvider { }); } let FastEmbedModel::Text(model) = fastembed_model else { - return Err(AIProxyError::AIModelNotSupported); + return Err(AIProxyError::AIModelNotSupported { model_name: self.supported_models.unwrap().to_string() }); }; let batch_size = 16; let store_keys = model @@ -224,7 +211,9 @@ impl ProviderTrait for FastEmbedProvider { }); store_keys } else { - Err(AIProxyError::AIModelNotSupported) + Err(AIProxyError::AIModelNotSupported { + model_name: self.supported_models.unwrap().to_string(), + }) }; } } diff --git a/ahnlich/ai/src/engine/ai/providers/mod.rs b/ahnlich/ai/src/engine/ai/providers/mod.rs index f2086f67..e00dc2f2 100644 --- a/ahnlich/ai/src/engine/ai/providers/mod.rs +++ b/ahnlich/ai/src/engine/ai/providers/mod.rs @@ -1,5 +1,6 @@ pub(crate) mod fastembed; pub(crate) mod ort; +mod ort_helper; use crate::cli::server::SupportedModels; use crate::engine::ai::models::{InputAction, ModelInput}; diff --git a/ahnlich/ai/src/engine/ai/providers/ort.rs b/ahnlich/ai/src/engine/ai/providers/ort.rs index 0b805272..6751672b 100644 --- 
a/ahnlich/ai/src/engine/ai/providers/ort.rs +++ b/ahnlich/ai/src/engine/ai/providers/ort.rs @@ -1,17 +1,19 @@ use crate::cli::server::SupportedModels; use crate::engine::ai::models::{ImageArray, InputAction, Model, ModelInput}; -use crate::engine::ai::providers::ProviderTrait; +use crate::engine::ai::providers::ort_helper::{get_tokenizer_artifacts_hf_hub, normalize, + load_tokenizer_artifacts_hf_hub}; +use crate::engine::ai::providers::{ProviderTrait, TextPreprocessorTrait}; use crate::error::AIProxyError; -use ahnlich_types::ai::AIStoreInputType; use fallible_collections::FallibleVec; use hf_hub::{api::sync::ApiBuilder, Cache}; use itertools::Itertools; -use ort::Session; use rayon::iter::Either; +use ort::{Session, Value}; +use tokenizers::Tokenizer; use rayon::prelude::*; use ahnlich_types::keyval::StoreKey; -use ndarray::{Array1, ArrayView, Axis, Ix3}; +use ndarray::{Array, Array1, ArrayView, Axis, Ix3}; use std::convert::TryFrom; use std::default::Default; use std::fmt; @@ -23,7 +25,8 @@ pub struct ORTProvider { cache_location: Option, cache_location_extension: PathBuf, supported_models: Option, - model: Option, + pub preprocessor: Option, + pub model: Option, } impl fmt::Debug for ORTProvider { @@ -41,36 +44,64 @@ pub struct ORTImageModel { repo_name: String, weights_file: String, session: Option, - input_param: String, + input_params: Vec, + output_param: String, +} + +#[derive(Default)] +pub struct ORTTextModel { + repo_name: String, + weights_file: String, + session: Option, + input_params: Vec, output_param: String, } pub enum ORTModel { Image(ORTImageModel), + Text(ORTTextModel) +} + +pub enum ORTPreprocessor { + Text(ORTTextPreprocessor), +} + +pub struct ORTTextPreprocessor { + tokenizer: Tokenizer, } impl TryFrom<&SupportedModels> for ORTModel { type Error = AIProxyError; fn try_from(model: &SupportedModels) -> Result { - let model_type = match model { - SupportedModels::Resnet50 => Ok(ORTImageModel { + let model_type: Result = match model { + SupportedModels::Resnet50 => Ok(ORTModel::Image(ORTImageModel { repo_name: "Qdrant/resnet50-onnx".to_string(), weights_file: "model.onnx".to_string(), - input_param: "input".to_string(), + input_params: vec!["input".to_string()], output_param: "image_embeds".to_string(), ..Default::default() - }), - SupportedModels::ClipVitB32 => Ok(ORTImageModel { + })), + SupportedModels::ClipVitB32Image => Ok(ORTModel::Image(ORTImageModel { repo_name: "Qdrant/clip-ViT-B-32-vision".to_string(), weights_file: "model.onnx".to_string(), - input_param: "pixel_values".to_string(), + input_params: vec!["pixel_values".to_string()], output_param: "image_embeds".to_string(), ..Default::default() + })), + SupportedModels::ClipVitB32Text => Ok(ORTModel::Text(ORTTextModel { + repo_name: "Qdrant/clip-ViT-B-32-text".to_string(), + weights_file: "model.onnx".to_string(), + input_params: vec!["input_ids".to_string(), "attention_mask".to_string()], + output_param: "text_embeds".to_string(), + ..Default::default() + })), + _ => Err(AIProxyError::AIModelNotSupported { + model_name: model.to_string(), }), - _ => Err(AIProxyError::AIModelNotSupported), }; - Ok(ORTModel::Image(model_type?)) + + model_type } } @@ -79,26 +110,16 @@ impl ORTProvider { Self { cache_location: None, cache_location_extension: PathBuf::from("huggingface"), + preprocessor: None, supported_models: None, model: None, } } - pub fn normalize(v: &[f32]) -> Vec { - let norm = (v.par_iter().map(|val| val * val).sum::()).sqrt(); - let epsilon = 1e-12; - - // We add the super-small epsilon to avoid dividing 
by zero - v.par_iter().map(|&val| val / (norm + epsilon)).collect() - } - - pub fn batch_inference( - &self, - mut inputs: Vec, - ) -> Result, AIProxyError> { + pub fn batch_inference_image(&self, mut inputs: Vec) -> Result, AIProxyError> { let model = match &self.model { Some(ORTModel::Image(model)) => model, - _ => return Err(AIProxyError::AIModelNotSupported), + _ => return Err(AIProxyError::AIModelNotSupported { model_name: self.supported_models.unwrap().to_string() }), }; let array_views: Vec> = inputs @@ -114,12 +135,96 @@ impl ORTProvider { match &model.session { Some(session) => { let session_inputs = ort::inputs![ - model.input_param.as_str() => pixel_values_array.view(), - ] - .map_err(|e| AIProxyError::ModelProviderPreprocessingError(e.to_string()))?; + model.input_params.get(0).expect("Hardcoded in parameters") + .as_str() => pixel_values_array.view(), + ].map_err(|e| AIProxyError::ModelProviderPreprocessingError(e.to_string()))?; + + let outputs = session.run(session_inputs) + .map_err(|e| AIProxyError::ModelProviderRunInferenceError(e.to_string()))?; + let last_hidden_state_key = match outputs.len() { + 1 => outputs.keys().next().unwrap(), + _ => model.output_param.as_str(), + }; + + let output_data = outputs[last_hidden_state_key] + .try_extract_tensor::() + .map_err(|e| AIProxyError::ModelProviderPostprocessingError(e.to_string()))?; + let store_keys = output_data + .axis_iter(Axis(0)) + .into_par_iter() + .map(|row| { + let embeddings = normalize(row.as_slice().unwrap()); + StoreKey(>::from(embeddings)) + }) + .collect(); + Ok(store_keys) + } + None => Err(AIProxyError::AIModelNotInitialized) + } + } + + pub fn batch_inference_text(&self, inputs: Vec) -> Result, AIProxyError> { + let inputs = inputs.iter().map(|x| x.as_str()).collect::>(); + let model = match &self.model { + Some(ORTModel::Text(model)) => model, + _ => return Err(AIProxyError::AIModelNotSupported { model_name: self.supported_models.unwrap().to_string() }), + }; + let batch_size = inputs.len(); + let encodings = match &self.preprocessor { + Some(ORTPreprocessor::Text(preprocessor)) => { + // TODO: We encode tokens at the preprocess step early in the workflow then also encode here. + // Find a way to store those encoded tokens for reuse here. + preprocessor.tokenizer.encode_batch(inputs, true).map_err(|_| { + AIProxyError::ModelTokenizationError + })? 
+ } + _ => return Err(AIProxyError::AIModelNotInitialized) + }; + + // Extract the encoding length and batch size + let encoding_length = encodings[0].len(); + + let max_size = encoding_length * batch_size; + + // Preallocate arrays with the maximum size + let mut ids_array = Vec::with_capacity(max_size); + let mut mask_array = Vec::with_capacity(max_size); + let mut typeids_array = Vec::with_capacity(max_size); - let outputs = session - .run(session_inputs) + // Not using par_iter because the closure needs to be FnMut + encodings.iter().for_each(|encoding| { + let ids = encoding.get_ids(); + let mask = encoding.get_attention_mask(); + let typeids = encoding.get_type_ids(); + + // Extend the preallocated arrays with the current encoding + // Requires the closure to be FnMut + ids_array.extend(ids.iter().map(|x| *x as i64)); + mask_array.extend(mask.iter().map(|x| *x as i64)); + typeids_array.extend(typeids.iter().map(|x| *x as i64)); + }); + + // Create CowArrays from vectors + let inputs_ids_array = + Array::from_shape_vec((batch_size, encoding_length), ids_array) + .map_err(|e| { + AIProxyError::ModelProviderPreprocessingError(e.to_string()) + })?; + + let attention_mask_array = + Array::from_shape_vec((batch_size, encoding_length), mask_array).map_err(|e| { + AIProxyError::ModelProviderPreprocessingError(e.to_string()) + })?; + + match &model.session { + Some(session) => { + let session_inputs = ort::inputs![ + model.input_params.get(0) + .expect("Hardcoded in parameters").as_str() => Value::from_array(inputs_ids_array)?, + model.input_params.get(1) + .expect("Hardcoded in parameters").as_str() => Value::from_array(attention_mask_array.view())? + ].map_err(|e| AIProxyError::ModelProviderPreprocessingError(e.to_string()))?; + let outputs = session.run(session_inputs) .map_err(|e| AIProxyError::ModelProviderRunInferenceError(e.to_string()))?; let last_hidden_state_key = match outputs.len() { 1 => outputs @@ -136,7 +241,7 @@ impl ORTProvider { .axis_iter(Axis(0)) .into_par_iter() .map(|row| { - let embeddings = ORTProvider::normalize(row.as_slice().unwrap()); + let embeddings = normalize(row.as_slice().unwrap()); StoreKey(>::from(embeddings)) }) .collect(); @@ -147,6 +252,38 @@ impl ORTProvider { } } +impl TextPreprocessorTrait for ORTProvider { + fn encode_str(&self, text: &str) -> Result, AIProxyError> { + match &self.model { + Some(ORTModel::Text(model)) => model, + _ => return Err(AIProxyError::AIModelNotSupported { model_name: self.supported_models.unwrap().to_string() }), + }; + + let Some(ORTPreprocessor::Text(preprocessor)) = &self.preprocessor else { + return Err(AIProxyError::AIModelNotInitialized); + }; + + let tokens = preprocessor.tokenizer.encode(text, true) + .map_err(|_| {AIProxyError::ModelTokenizationError})?; + Ok(tokens.get_ids().iter().map(|x| *x as usize).collect()) + } + + fn decode_tokens(&self, tokens: Vec) -> Result { + match &self.model { + Some(ORTModel::Text(model)) => model, + _ => return Err(AIProxyError::AIModelNotSupported { model_name: self.supported_models.unwrap().to_string() }), + }; + + let Some(ORTPreprocessor::Text(preprocessor)) = &self.preprocessor else { + return Err(AIProxyError::AIModelNotInitialized); + }; + + let tokens = tokens.iter().map(|x| *x as u32).collect::>(); + Ok(preprocessor.tokenizer.decode(&tokens, true) + .map_err(|_| {AIProxyError::ModelTokenizationError})?) 
+ } +} + impl ProviderTrait for ORTProvider { fn set_cache_location(&mut self, location: &Path) { self.cache_location = Some(location.join(self.cache_location_extension.clone())); @@ -179,7 +316,7 @@ impl ProviderTrait for ORTProvider { ORTModel::Image(ORTImageModel { weights_file, repo_name, - input_param, + input_params: input_param, output_param, .. }) => { @@ -193,10 +330,42 @@ impl ProviderTrait for ORTProvider { self.model = Some(ORTModel::Image(ORTImageModel { repo_name, weights_file, - input_param, + input_params: input_param, + output_param, + session: Some(session), + })); + }, + ORTModel::Text(ORTTextModel { + weights_file, + repo_name, + input_params, + output_param, + .. + }) => { + let model_repo = api.model(repo_name.clone()); + let model_file_reference = model_repo + .get(&weights_file) + .map_err(|e| AIProxyError::APIBuilderError(e.to_string()))?; + let session = Session::builder()? + .with_intra_threads(threads)? + .commit_from_file(model_file_reference)?; + let max_token_length = Model::from(&(self.supported_models + .ok_or(AIProxyError::AIModelNotInitialized)?)) + .max_input_token() + .ok_or(AIProxyError::AIModelNotInitialized)?; + let tokenizer = load_tokenizer_artifacts_hf_hub(&model_repo, + usize::from(max_token_length))?; + self.model = Some(ORTModel::Text(ORTTextModel { + repo_name, + weights_file, + input_params, output_param, session: Some(session), })); + self.preprocessor = Some(ORTPreprocessor::Text( + ORTTextPreprocessor { + tokenizer, + })); } } Ok(()) @@ -231,6 +400,14 @@ impl ProviderTrait for ORTProvider { .get("preprocessor_config.json") .map_err(|e| AIProxyError::APIBuilderError(e.to_string()))?; Ok(()) + }, + ORTModel::Text(ORTTextModel { + repo_name, + .. + }) => { + let model_repo = api.model(repo_name); + get_tokenizer_artifacts_hf_hub(&model_repo)?; + Ok(()) } } } @@ -238,7 +415,7 @@ impl ProviderTrait for ORTProvider { fn run_inference( &self, inputs: Vec, - action_type: &InputAction, + _action_type: &InputAction, ) -> Result, AIProxyError> { let (string_inputs, image_inputs): (Vec, Vec) = inputs.into_par_iter().partition_map(|input| match input { @@ -246,22 +423,21 @@ impl ProviderTrait for ORTProvider { ModelInput::Image(value) => Either::Right(value), }); - if !string_inputs.is_empty() { - let store_input_type: AIStoreInputType = AIStoreInputType::RawString; - let Some(index_model_repr) = self.supported_models else { - return Err(AIProxyError::AIModelNotInitialized); - }; - let index_model_repr: Model = (&index_model_repr).into(); - return Err(AIProxyError::StoreTypeMismatchError { - action: *action_type, - index_model_type: index_model_repr.input_type(), - storeinput_type: store_input_type, - }); + if !image_inputs.is_empty() && !string_inputs.is_empty() { + return Err(AIProxyError::VaryingInferenceInputTypes) } let batch_size = 16; - let mut store_keys: Vec<_> = FallibleVec::try_with_capacity(image_inputs.len())?; - for batch_inputs in image_inputs.into_iter().chunks(batch_size).into_iter() { - store_keys.extend(self.batch_inference(batch_inputs.collect())?); + let mut store_keys: Vec<_> = FallibleVec::try_with_capacity( + image_inputs.len().max(string_inputs.len()) + )?; + if !image_inputs.is_empty() { + for batch_inputs in image_inputs.into_iter().chunks(batch_size).into_iter() { + store_keys.extend(self.batch_inference_image(batch_inputs.collect())?); + } + } else { + for batch_inputs in string_inputs.into_iter().chunks(batch_size).into_iter() { + store_keys.extend(self.batch_inference_text(batch_inputs.collect())?); + } } Ok(store_keys) } 
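The text path added above follows tokenize -> pad -> i64 tensors -> session.run -> L2-normalize: `batch_inference_text` encodes the batch, flattens each encoding's ids and attention mask into (batch_size, encoding_length) i64 arrays, binds them to the session under the model's `input_params` names, and normalizes every output row into a `StoreKey`. Below is a minimal sketch of the two array-building steps, using only `ndarray`; `build_input_ids` and `normalize_rows` are illustrative names, not functions in this patch, and the sketch assumes the tokenizer has already padded all sequences to equal length (as the BatchLongest padding configured in ort_helper.rs below guarantees):

    use ndarray::{Array, Array2, Axis};

    /// Flatten per-example token ids into one (batch, seq_len) i64 array,
    /// the same layout `batch_inference_text` builds in `ids_array`.
    fn build_input_ids(token_ids: &[Vec<u32>]) -> Result<Array2<i64>, ndarray::ShapeError> {
        let batch_size = token_ids.len();
        let seq_len = token_ids.first().map(|t| t.len()).unwrap_or(0);
        let flat: Vec<i64> = token_ids.iter().flatten().map(|&x| x as i64).collect();
        Array::from_shape_vec((batch_size, seq_len), flat)
    }

    /// L2-normalize each embedding row, with the same epsilon guard as
    /// `ort_helper::normalize` to avoid dividing by zero.
    fn normalize_rows(embeddings: &Array2<f32>) -> Vec<Vec<f32>> {
        let epsilon = 1e-12_f32;
        embeddings
            .axis_iter(Axis(0))
            .map(|row| {
                let norm = row.iter().map(|v| v * v).sum::<f32>().sqrt();
                row.iter().map(|&v| v / (norm + epsilon)).collect()
            })
            .collect()
    }

    fn main() {
        let ids = vec![vec![49406_u32, 320, 49407], vec![49406, 1125, 49407]];
        let input_ids = build_input_ids(&ids).expect("sequences padded to equal length");
        assert_eq!(input_ids.shape(), &[2, 3]);
        let embeds = Array2::from_shape_vec((1, 2), vec![3.0_f32, 4.0]).unwrap();
        // A (3, 4) row normalizes to (0.6, 0.8).
        println!("{:?}", normalize_rows(&embeds));
    }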
diff --git a/ahnlich/ai/src/engine/ai/providers/ort_helper.rs b/ahnlich/ai/src/engine/ai/providers/ort_helper.rs new file mode 100644 index 00000000..f32ac23c --- /dev/null +++ b/ahnlich/ai/src/engine/ai/providers/ort_helper.rs @@ -0,0 +1,119 @@ +use hf_hub::api::sync::ApiRepo; +use std::io::Read; +use std::{fs::File, path::PathBuf}; +use tokenizers::{AddedToken, PaddingParams, PaddingStrategy, Tokenizer, TruncationParams}; +use crate::error::AIProxyError; + + +// Tokenizer files for "bring your own" models +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct TokenizerFiles { + pub tokenizer_file: Vec, + pub config_file: Vec, + pub special_tokens_map_file: Vec, + pub tokenizer_config_file: Vec, +} + +/// The procedure for loading tokenizer files from the hugging face hub is separated +/// from the main load_tokenizer function (which is expecting bytes, from any source). +pub fn load_tokenizer_artifacts_hf_hub(model_repo: &ApiRepo, max_length: usize) -> Result { + let tokenizer_files: TokenizerFiles = get_tokenizer_artifacts_hf_hub(model_repo)?; + load_tokenizer(tokenizer_files, max_length) +} + +pub fn get_tokenizer_artifacts_hf_hub(model_repo: &ApiRepo) -> Result { + Ok(TokenizerFiles { + tokenizer_file: read_file_to_bytes(&model_repo.get("tokenizer.json") + .map_err(|_| AIProxyError::ModelTokenizerLoadError)?) + .map_err(|_| AIProxyError::ModelTokenizerLoadError)?, + config_file: read_file_to_bytes(&model_repo.get("config.json") + .map_err(|_| AIProxyError::ModelTokenizerLoadError)?) + .map_err(|_| AIProxyError::ModelTokenizerLoadError)?, + special_tokens_map_file: read_file_to_bytes(&model_repo.get("special_tokens_map.json") + .map_err(|_| AIProxyError::ModelTokenizerLoadError)?) + .map_err(|_| AIProxyError::ModelTokenizerLoadError)?, + tokenizer_config_file: read_file_to_bytes(&model_repo.get("tokenizer_config.json") + .map_err(|_| AIProxyError::ModelTokenizerLoadError)?) + .map_err(|_| AIProxyError::ModelTokenizerLoadError)?, + }) +} + +/// Function can be called directly from the try_new_from_user_defined function (providing file bytes) +/// +/// Or indirectly from the try_new function via load_tokenizer_hf_hub (converting HF files to bytes) +pub fn load_tokenizer(tokenizer_files: TokenizerFiles, max_length: usize) -> Result { + // Serialise each tokenizer file + let config: serde_json::Value = + serde_json::from_slice(&tokenizer_files.config_file).map_err(|_| AIProxyError::ModelTokenizerLoadError)?; + let special_tokens_map: serde_json::Value = + serde_json::from_slice(&tokenizer_files.special_tokens_map_file).map_err(|_| AIProxyError::ModelTokenizerLoadError)?; + let tokenizer_config: serde_json::Value = + serde_json::from_slice(&tokenizer_files.tokenizer_config_file).map_err(|_| AIProxyError::ModelTokenizerLoadError)?; + let mut tokenizer: tokenizers::Tokenizer = + tokenizers::Tokenizer::from_bytes(tokenizer_files.tokenizer_file).map_err(|_| AIProxyError::ModelTokenizerLoadError)?; + + //For BGEBaseSmall, the model_max_length value is set to 1000000000000000019884624838656. Which fits in a f64 + let model_max_length = tokenizer_config["model_max_length"] + .as_f64() + .ok_or(AIProxyError::ModelTokenizerLoadError)? as f32; + let max_length = max_length.min(model_max_length as usize); + let pad_id = config["pad_token_id"].as_u64().unwrap_or(0) as u32; + let pad_token = tokenizer_config["pad_token"] + .as_str() + .ok_or(AIProxyError::ModelTokenizerLoadError)? 
+ .into(); + + let mut tokenizer = tokenizer + .with_padding(Some(PaddingParams { + // TODO: the user should able to choose the padding strategy + strategy: PaddingStrategy::BatchLongest, + pad_token, + pad_id, + ..Default::default() + })) + .with_truncation(Some(TruncationParams { + max_length, + ..Default::default() + })) + .map_err(|_| AIProxyError::ModelTokenizerLoadError)? + .clone(); + if let serde_json::Value::Object(root_object) = special_tokens_map { + for (_, value) in root_object.iter() { + if value.is_string() { + tokenizer.add_special_tokens(&[AddedToken { + content: value.as_str().unwrap().into(), + special: true, + ..Default::default() + }]); + } else if value.is_object() { + tokenizer.add_special_tokens(&[AddedToken { + content: value["content"].as_str().unwrap().into(), + special: true, + single_word: value["single_word"].as_bool().unwrap(), + lstrip: value["lstrip"].as_bool().unwrap(), + rstrip: value["rstrip"].as_bool().unwrap(), + normalized: value["normalized"].as_bool().unwrap(), + }]); + } + } + } + Ok(tokenizer.into()) +} + +pub fn normalize(v: &[f32]) -> Vec { + let norm = (v.iter().map(|val| val * val).sum::()).sqrt(); + let epsilon = 1e-12; + + // We add the super-small epsilon to avoid dividing by zero + v.iter().map(|&val| val / (norm + epsilon)).collect() +} + +/// Public function to read a file to bytes. +/// To be used when loading local model files. +pub fn read_file_to_bytes(file: &PathBuf) -> Result, AIProxyError> { + let mut file = File::open(file).map_err(|_| AIProxyError::ModelTokenizerLoadError)?; + let file_size = file.metadata().map_err(|_| AIProxyError::ModelTokenizerLoadError)?.len() as usize; + let mut buffer = Vec::with_capacity(file_size); + file.read_to_end(&mut buffer).map_err(|_| AIProxyError::ModelTokenizerLoadError)?; + Ok(buffer) +} \ No newline at end of file diff --git a/ahnlich/ai/src/error.rs b/ahnlich/ai/src/error.rs index 35f4d8b2..97a6479a 100644 --- a/ahnlich/ai/src/error.rs +++ b/ahnlich/ai/src/error.rs @@ -69,8 +69,10 @@ pub enum AIProxyError { #[error("Cache location for model was not initialized")] CacheLocationNotInitiailized, - #[error("index_model or query_model not supported")] - AIModelNotSupported, + #[error("index_model or query_model [{model_name}] not supported")] + AIModelNotSupported { + model_name: String + }, // TODO: Add SendError from mpsc::Sender into this variant #[error("Error sending request to model thread")] @@ -79,7 +81,7 @@ pub enum AIProxyError { #[error("Error receiving response from model thread")] AIModelRecvError(#[from] RecvError), - #[error("Dimensions Mismatch between index [{index_model_dim}], and Query [{query_model_dim}] Models")] + #[error("Dimensions Mismatch between index [{index_model_dim}], and Query [{query_model_dim}] Models.")] DimensionsMismatchError { index_model_dim: usize, query_model_dim: usize, @@ -87,7 +89,7 @@ pub enum AIProxyError { #[error("allocation error {0:?}")] Allocation(TryReserveError), - #[error("Error initializing a model thread {0}")] + #[error("Error initializing a model thread {0}.")] ModelInitializationError(String), #[error("Bytes could not be successfully decoded into an image.")] @@ -104,6 +106,9 @@ pub enum AIProxyError { #[error("Model provider failed on preprocessing the input {0}")] ModelProviderPreprocessingError(String), + #[error("Inference can only be run on one of text or image inputs, not both.")] + VaryingInferenceInputTypes, + #[error("Model provider failed on running inference {0}")] ModelProviderRunInferenceError(String), @@ -115,6 +120,9 @@ pub 
enum AIProxyError { #[error("Cannot call DelKey on store with `store_original` as false")] DelKeyError, + + #[error("Tokenizer for model failed on loading.")] + ModelTokenizerLoadError } impl From for AIProxyError { diff --git a/ahnlich/ai/src/manager/mod.rs b/ahnlich/ai/src/manager/mod.rs index d7a3b3a0..2a1fc13d 100644 --- a/ahnlich/ai/src/manager/mod.rs +++ b/ahnlich/ai/src/manager/mod.rs @@ -7,7 +7,7 @@ use crate::engine::ai::models::{ImageArray, InputAction}; /// lets AIProxyTasks communicate with any model to receive immediate responses via a oneshot /// channel use crate::engine::ai::models::{Model, ModelInput}; -use crate::engine::ai::providers::ModelProviders; +use crate::engine::ai::providers::{ModelProviders, TextPreprocessorTrait}; use crate::error::AIProxyError; use ahnlich_types::ai::{AIModel, AIStoreInputType, ImageAction, PreprocessAction, StringAction}; use ahnlich_types::keyval::{StoreInput, StoreKey}; @@ -122,24 +122,30 @@ impl ModelThread { return Err(AIProxyError::TokenTruncationNotSupported); } - let ModelProviders::FastEmbed(provider) = &self.model.provider else { - return Err(AIProxyError::TokenTruncationNotSupported); - }; - - let tokens = provider.encode_str(&input)?; - - if tokens.len() > max_token_size.into() { - if let StringAction::ErrorIfTokensExceed = string_action { + let process = |provider: &dyn TextPreprocessorTrait, input| { + let mut tokens = provider.encode_str(input)?; + let max_token_size: usize = max_token_size.into(); + if (tokens.len() > max_token_size) && + (string_action == StringAction::ErrorIfTokensExceed) { return Err(AIProxyError::TokenExceededError { input_token_size: tokens.len(), - max_token_size: max_token_size.into(), + max_token_size, }); } else { + tokens.truncate(max_token_size); let processed_input = provider.decode_tokens(tokens)?; return Ok(processed_input); } }; - Ok(input) + + match &self.model.provider { + ModelProviders::FastEmbed(provider) => { + process(provider, &input) + }, + ModelProviders::ORT(provider) => { + process(provider, &input) + } + } } #[tracing::instrument(skip(self, input))] diff --git a/ahnlich/dsl/src/ai.rs b/ahnlich/dsl/src/ai.rs index f2967880..c15d875c 100644 --- a/ahnlich/dsl/src/ai.rs +++ b/ahnlich/dsl/src/ai.rs @@ -39,7 +39,7 @@ fn parse_to_ai_model(input: &str) -> Result { "bge-base-en-v1.5" => Ok(AIModel::BGEBaseEnV15), "bge-large-en-v1.5" => Ok(AIModel::BGELargeEnV15), "resnet-50" => Ok(AIModel::Resnet50), - "clip-vit-b32" => Ok(AIModel::ClipVitB32), + "clip-vit-b32-image" => Ok(AIModel::ClipVitB32Image), e => Err(DslError::UnsupportedAIModel(e.to_string())), } } diff --git a/ahnlich/types/src/ai/mod.rs b/ahnlich/types/src/ai/mod.rs index 4c4ec1b6..71f3d6cb 100644 --- a/ahnlich/types/src/ai/mod.rs +++ b/ahnlich/types/src/ai/mod.rs @@ -16,7 +16,8 @@ pub enum AIModel { BGEBaseEnV15, BGELargeEnV15, Resnet50, - ClipVitB32, + ClipVitB32Image, + ClipVitB32Text, } #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord, Hash)] diff --git a/sdk/ahnlich-client-py/ahnlich_client_py/internals/ai_query.py b/sdk/ahnlich-client-py/ahnlich_client_py/internals/ai_query.py index ac3df9d6..37fb48e5 100644 --- a/sdk/ahnlich-client-py/ahnlich_client_py/internals/ai_query.py +++ b/sdk/ahnlich-client-py/ahnlich_client_py/internals/ai_query.py @@ -51,10 +51,15 @@ class AIModel__Resnet50(AIModel): @dataclass(frozen=True) -class AIModel__ClipVitB32(AIModel): +class AIModel__ClipVitB32Image(AIModel): INDEX = 5 # type: int pass +@dataclass(frozen=True) +class AIModel__ClipVitB32Text(AIModel): + INDEX = 
6 # type: int + pass + AIModel.VARIANTS = [ AIModel__AllMiniLML6V2, @@ -62,7 +67,8 @@ class AIModel__ClipVitB32(AIModel): AIModel__BGEBaseEnV15, AIModel__BGELargeEnV15, AIModel__Resnet50, - AIModel__ClipVitB32, + AIModel__ClipVitB32Image, + AIModel__ClipVitB32Text, ] diff --git a/sdk/ahnlich-client-py/ahnlich_client_py/internals/ai_response.py b/sdk/ahnlich-client-py/ahnlich_client_py/internals/ai_response.py index 30982c5d..9e8cda32 100644 --- a/sdk/ahnlich-client-py/ahnlich_client_py/internals/ai_response.py +++ b/sdk/ahnlich-client-py/ahnlich_client_py/internals/ai_response.py @@ -51,7 +51,7 @@ class AIModel__Resnet50(AIModel): @dataclass(frozen=True) -class AIModel__ClipVitB32(AIModel): +class AIModel__ClipVitB32Image(AIModel): INDEX = 5 # type: int pass @@ -62,7 +62,7 @@ class AIModel__ClipVitB32(AIModel): AIModel__BGEBaseEnV15, AIModel__BGELargeEnV15, AIModel__Resnet50, - AIModel__ClipVitB32, + AIModel__ClipVitB32Image, ] diff --git a/sdk/ahnlich-client-py/demo_embed.py b/sdk/ahnlich-client-py/demo_embed.py index 817ab89b..c5c5c302 100644 --- a/sdk/ahnlich-client-py/demo_embed.py +++ b/sdk/ahnlich-client-py/demo_embed.py @@ -8,14 +8,14 @@ ai_store_payload_no_predicates = { "store_name": "Diretnan Stores", - "query_model": ai_query.AIModel__AllMiniLML6V2(), - "index_model": ai_query.AIModel__AllMiniLML6V2(), + "query_model": ai_query.AIModel__ClipVitB32Text(), + "index_model": ai_query.AIModel__ClipVitB32Text(), } ai_store_payload_with_predicates = { "store_name": "Diretnan Predication Stores", - "query_model": ai_query.AIModel__AllMiniLML6V2(), - "index_model": ai_query.AIModel__AllMiniLML6V2(), + "query_model": ai_query.AIModel__ClipVitB32Text(), + "index_model": ai_query.AIModel__ClipVitB32Text(), "predicates": ["special", "brand"], } diff --git a/type_specs/query/ai_query.json b/type_specs/query/ai_query.json index c2d9f2e2..d8184808 100644 --- a/type_specs/query/ai_query.json +++ b/type_specs/query/ai_query.json @@ -17,7 +17,7 @@ "Resnet50": "UNIT" }, "5": { - "ClipVitB32": "UNIT" + "ClipVitB32Image": "UNIT" } } }, diff --git a/type_specs/response/ai_response.json b/type_specs/response/ai_response.json index 632ec280..997acf1c 100644 --- a/type_specs/response/ai_response.json +++ b/type_specs/response/ai_response.json @@ -17,7 +17,7 @@ "Resnet50": "UNIT" }, "5": { - "ClipVitB32": "UNIT" + "ClipVitB32Image": "UNIT" } } }, From a192e728c10d61a2aeec636aeeb05734ef49adc2 Mon Sep 17 00:00:00 2001 From: HabeebShopeju Date: Wed, 30 Oct 2024 02:47:13 +0000 Subject: [PATCH 02/15] Ran clippy and typegen --- ahnlich/ai/src/engine/ai/providers/ort.rs | 8 +- .../ai/src/engine/ai/providers/ort_helper.rs | 8 +- ahnlich/ai/src/manager/mod.rs | 6 +- .../ahnlich_client_py/internals/ai_query.py | 81 +++++++---------- .../internals/ai_response.py | 86 ++++++++----------- .../internals/bincode/__init__.py | 4 +- .../ahnlich_client_py/internals/db_query.py | 45 ++++------ .../internals/db_response.py | 67 +++++++-------- .../internals/serde_binary/__init__.py | 2 +- .../internals/serde_types/__init__.py | 5 +- 10 files changed, 131 insertions(+), 181 deletions(-) diff --git a/ahnlich/ai/src/engine/ai/providers/ort.rs b/ahnlich/ai/src/engine/ai/providers/ort.rs index 6751672b..1c27ec70 100644 --- a/ahnlich/ai/src/engine/ai/providers/ort.rs +++ b/ahnlich/ai/src/engine/ai/providers/ort.rs @@ -135,7 +135,7 @@ impl ORTProvider { match &model.session { Some(session) => { let session_inputs = ort::inputs![ - model.input_params.get(0).expect("Hardcoded in parameters") + 
model.input_params.first().expect("Hardcoded in parameters") .as_str() => pixel_values_array.view(), ].map_err(|e| AIProxyError::ModelProviderPreprocessingError(e.to_string()))?; @@ -219,7 +219,7 @@ impl ORTProvider { match &model.session { Some(session) => { let session_inputs = ort::inputs![ - model.input_params.get(0) + model.input_params.first() .expect("Hardcoded in parameters").as_str() => Value::from_array(inputs_ids_array)?, model.input_params.get(1) .expect("Hardcoded in parameters").as_str() => Value::from_array(attention_mask_array.view())? @@ -279,8 +279,8 @@ impl TextPreprocessorTrait for ORTProvider { }; let tokens = tokens.iter().map(|x| *x as u32).collect::>(); - Ok(preprocessor.tokenizer.decode(&tokens, true) - .map_err(|_| {AIProxyError::ModelTokenizationError})?) + preprocessor.tokenizer.decode(&tokens, true) + .map_err(|_| {AIProxyError::ModelTokenizationError}) } } diff --git a/ahnlich/ai/src/engine/ai/providers/ort_helper.rs b/ahnlich/ai/src/engine/ai/providers/ort_helper.rs index f32ac23c..7a2a0228 100644 --- a/ahnlich/ai/src/engine/ai/providers/ort_helper.rs +++ b/ahnlich/ai/src/engine/ai/providers/ort_helper.rs @@ -1,11 +1,13 @@ +// This script was adapted from FastEmbed +// https://github.com/Anush008/fastembed-rs/blob/474d4e62c87666781b580ffc076b8475b693fc34/src/common.rs use hf_hub::api::sync::ApiRepo; +use rayon::prelude::*; use std::io::Read; use std::{fs::File, path::PathBuf}; use tokenizers::{AddedToken, PaddingParams, PaddingStrategy, Tokenizer, TruncationParams}; use crate::error::AIProxyError; -// Tokenizer files for "bring your own" models #[derive(Debug, Clone, PartialEq, Eq)] pub struct TokenizerFiles { pub tokenizer_file: Vec, @@ -101,11 +103,11 @@ pub fn load_tokenizer(tokenizer_files: TokenizerFiles, max_length: usize) -> Res } pub fn normalize(v: &[f32]) -> Vec { - let norm = (v.iter().map(|val| val * val).sum::()).sqrt(); + let norm = (v.par_iter().map(|val| val * val).sum::()).sqrt(); let epsilon = 1e-12; // We add the super-small epsilon to avoid dividing by zero - v.iter().map(|&val| val / (norm + epsilon)).collect() + v.par_iter().map(|&val| val / (norm + epsilon)).collect() } /// Public function to read a file to bytes. 
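The tokenizer side of that contract lives in `load_tokenizer` above: it pins padding to the longest sequence in the batch (so the input_ids array stays rectangular) and truncation to the model's max token length (77 for ClipVitB32Text). Below is a minimal sketch of that configuration against the tokenizers 0.20 API used in this patch; `configure_for_batching` is an illustrative name, and unlike the real helper it leaves `pad_token`/`pad_id` at their defaults rather than reading them from the model's config files:

    use tokenizers::{PaddingParams, PaddingStrategy, Tokenizer, TruncationParams};

    fn configure_for_batching(
        mut tokenizer: Tokenizer,
        max_length: usize,
    ) -> Result<Tokenizer, tokenizers::Error> {
        tokenizer
            // Pad every encoding in a batch to its longest member.
            .with_padding(Some(PaddingParams {
                strategy: PaddingStrategy::BatchLongest,
                ..Default::default()
            }))
            // Truncate anything beyond the model's context window.
            .with_truncation(Some(TruncationParams {
                max_length,
                ..Default::default()
            }))?;
        Ok(tokenizer)
    }

BatchLongest keeps short batches cheap compared with fixed-length padding, at the cost of variable sequence lengths across batches.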
diff --git a/ahnlich/ai/src/manager/mod.rs b/ahnlich/ai/src/manager/mod.rs index 2a1fc13d..603ac62f 100644 --- a/ahnlich/ai/src/manager/mod.rs +++ b/ahnlich/ai/src/manager/mod.rs @@ -127,14 +127,14 @@ impl ModelThread { let max_token_size: usize = max_token_size.into(); if (tokens.len() > max_token_size) && (string_action == StringAction::ErrorIfTokensExceed) { - return Err(AIProxyError::TokenExceededError { + Err(AIProxyError::TokenExceededError { input_token_size: tokens.len(), max_token_size, - }); + }) } else { tokens.truncate(max_token_size); let processed_input = provider.decode_tokens(tokens)?; - return Ok(processed_input); + Ok(processed_input) } }; diff --git a/sdk/ahnlich-client-py/ahnlich_client_py/internals/ai_query.py b/sdk/ahnlich-client-py/ahnlich_client_py/internals/ai_query.py index 37fb48e5..96b352d9 100644 --- a/sdk/ahnlich-client-py/ahnlich_client_py/internals/ai_query.py +++ b/sdk/ahnlich-client-py/ahnlich_client_py/internals/ai_query.py @@ -1,10 +1,8 @@ # pyre-strict -import typing from dataclasses import dataclass - -from ahnlich_client_py.internals import bincode +import typing from ahnlich_client_py.internals import serde_types as st - +from ahnlich_client_py.internals import bincode class AIModel: VARIANTS = [] # type: typing.Sequence[typing.Type[AIModel]] @@ -13,10 +11,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, AIModel) @staticmethod - def bincode_deserialize(input: bytes) -> "AIModel": + def bincode_deserialize(input: bytes) -> 'AIModel': v, buffer = bincode.deserialize(input, AIModel) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v @@ -55,12 +53,6 @@ class AIModel__ClipVitB32Image(AIModel): INDEX = 5 # type: int pass -@dataclass(frozen=True) -class AIModel__ClipVitB32Text(AIModel): - INDEX = 6 # type: int - pass - - AIModel.VARIANTS = [ AIModel__AllMiniLML6V2, AIModel__AllMiniLML12V2, @@ -68,7 +60,6 @@ class AIModel__ClipVitB32Text(AIModel): AIModel__BGELargeEnV15, AIModel__Resnet50, AIModel__ClipVitB32Image, - AIModel__ClipVitB32Text, ] @@ -79,10 +70,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, AIQuery) @staticmethod - def bincode_deserialize(input: bytes) -> "AIQuery": + def bincode_deserialize(input: bytes) -> 'AIQuery': v, buffer = bincode.deserialize(input, AIQuery) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v @@ -149,9 +140,7 @@ class AIQuery__DropNonLinearAlgorithmIndex(AIQuery): class AIQuery__Set(AIQuery): INDEX = 7 # type: int store: str - inputs: typing.Sequence[ - typing.Tuple["StoreInput", typing.Dict[str, "MetadataValue"]] - ] + inputs: typing.Sequence[typing.Tuple["StoreInput", typing.Dict[str, "MetadataValue"]]] preprocess_action: "PreprocessAction" @@ -192,7 +181,6 @@ class AIQuery__Ping(AIQuery): INDEX = 13 # type: int pass - AIQuery.VARIANTS = [ AIQuery__CreateStore, AIQuery__GetPred, @@ -220,10 +208,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, AIServerQuery) @staticmethod - def bincode_deserialize(input: bytes) -> "AIServerQuery": + def bincode_deserialize(input: bytes) -> 'AIServerQuery': v, buffer = bincode.deserialize(input, AIServerQuery) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v @@ -234,10 +222,10 @@ def bincode_serialize(self) 
-> bytes: return bincode.serialize(self, AIStoreInputType) @staticmethod - def bincode_deserialize(input: bytes) -> "AIStoreInputType": + def bincode_deserialize(input: bytes) -> 'AIStoreInputType': v, buffer = bincode.deserialize(input, AIStoreInputType) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v @@ -252,7 +240,6 @@ class AIStoreInputType__Image(AIStoreInputType): INDEX = 1 # type: int pass - AIStoreInputType.VARIANTS = [ AIStoreInputType__RawString, AIStoreInputType__Image, @@ -266,10 +253,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, Algorithm) @staticmethod - def bincode_deserialize(input: bytes) -> "Algorithm": + def bincode_deserialize(input: bytes) -> 'Algorithm': v, buffer = bincode.deserialize(input, Algorithm) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v @@ -296,7 +283,6 @@ class Algorithm__KDTree(Algorithm): INDEX = 3 # type: int pass - Algorithm.VARIANTS = [ Algorithm__EuclideanDistance, Algorithm__DotProductSimilarity, @@ -312,10 +298,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, ImageAction) @staticmethod - def bincode_deserialize(input: bytes) -> "ImageAction": + def bincode_deserialize(input: bytes) -> 'ImageAction': v, buffer = bincode.deserialize(input, ImageAction) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v @@ -330,7 +316,6 @@ class ImageAction__ErrorIfDimensionsMismatch(ImageAction): INDEX = 1 # type: int pass - ImageAction.VARIANTS = [ ImageAction__ResizeImage, ImageAction__ErrorIfDimensionsMismatch, @@ -344,10 +329,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, MetadataValue) @staticmethod - def bincode_deserialize(input: bytes) -> "MetadataValue": + def bincode_deserialize(input: bytes) -> 'MetadataValue': v, buffer = bincode.deserialize(input, MetadataValue) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v @@ -362,7 +347,6 @@ class MetadataValue__Image(MetadataValue): INDEX = 1 # type: int value: typing.Sequence[st.uint8] - MetadataValue.VARIANTS = [ MetadataValue__RawString, MetadataValue__Image, @@ -376,10 +360,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, NonLinearAlgorithm) @staticmethod - def bincode_deserialize(input: bytes) -> "NonLinearAlgorithm": + def bincode_deserialize(input: bytes) -> 'NonLinearAlgorithm': v, buffer = bincode.deserialize(input, NonLinearAlgorithm) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v @@ -388,7 +372,6 @@ class NonLinearAlgorithm__KDTree(NonLinearAlgorithm): INDEX = 0 # type: int pass - NonLinearAlgorithm.VARIANTS = [ NonLinearAlgorithm__KDTree, ] @@ -401,10 +384,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, Predicate) @staticmethod - def bincode_deserialize(input: bytes) -> "Predicate": + def bincode_deserialize(input: bytes) -> 'Predicate': v, buffer = bincode.deserialize(input, Predicate) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v @@ -435,7 +418,6 
@@ class Predicate__NotIn(Predicate): key: str value: typing.Sequence["MetadataValue"] - Predicate.VARIANTS = [ Predicate__Equals, Predicate__NotEquals, @@ -451,10 +433,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, PredicateCondition) @staticmethod - def bincode_deserialize(input: bytes) -> "PredicateCondition": + def bincode_deserialize(input: bytes) -> 'PredicateCondition': v, buffer = bincode.deserialize(input, PredicateCondition) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v @@ -475,7 +457,6 @@ class PredicateCondition__Or(PredicateCondition): INDEX = 2 # type: int value: typing.Tuple["PredicateCondition", "PredicateCondition"] - PredicateCondition.VARIANTS = [ PredicateCondition__Value, PredicateCondition__And, @@ -490,10 +471,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, PreprocessAction) @staticmethod - def bincode_deserialize(input: bytes) -> "PreprocessAction": + def bincode_deserialize(input: bytes) -> 'PreprocessAction': v, buffer = bincode.deserialize(input, PreprocessAction) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v @@ -508,7 +489,6 @@ class PreprocessAction__Image(PreprocessAction): INDEX = 1 # type: int value: "ImageAction" - PreprocessAction.VARIANTS = [ PreprocessAction__RawString, PreprocessAction__Image, @@ -522,10 +502,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, StoreInput) @staticmethod - def bincode_deserialize(input: bytes) -> "StoreInput": + def bincode_deserialize(input: bytes) -> 'StoreInput': v, buffer = bincode.deserialize(input, StoreInput) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v @@ -540,7 +520,6 @@ class StoreInput__Image(StoreInput): INDEX = 1 # type: int value: typing.Sequence[st.uint8] - StoreInput.VARIANTS = [ StoreInput__RawString, StoreInput__Image, @@ -554,10 +533,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, StringAction) @staticmethod - def bincode_deserialize(input: bytes) -> "StringAction": + def bincode_deserialize(input: bytes) -> 'StringAction': v, buffer = bincode.deserialize(input, StringAction) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v @@ -572,8 +551,8 @@ class StringAction__ErrorIfTokensExceed(StringAction): INDEX = 1 # type: int pass - StringAction.VARIANTS = [ StringAction__TruncateIfTokensExceed, StringAction__ErrorIfTokensExceed, ] + diff --git a/sdk/ahnlich-client-py/ahnlich_client_py/internals/ai_response.py b/sdk/ahnlich-client-py/ahnlich_client_py/internals/ai_response.py index 9e8cda32..838cb40d 100644 --- a/sdk/ahnlich-client-py/ahnlich_client_py/internals/ai_response.py +++ b/sdk/ahnlich-client-py/ahnlich_client_py/internals/ai_response.py @@ -1,10 +1,8 @@ # pyre-strict -import typing from dataclasses import dataclass - -from ahnlich_client_py.internals import bincode +import typing from ahnlich_client_py.internals import serde_types as st - +from ahnlich_client_py.internals import bincode class AIModel: VARIANTS = [] # type: typing.Sequence[typing.Type[AIModel]] @@ -13,10 +11,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, AIModel) @staticmethod - def 
bincode_deserialize(input: bytes) -> "AIModel": + def bincode_deserialize(input: bytes) -> 'AIModel': v, buffer = bincode.deserialize(input, AIModel) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v @@ -55,7 +53,6 @@ class AIModel__ClipVitB32Image(AIModel): INDEX = 5 # type: int pass - AIModel.VARIANTS = [ AIModel__AllMiniLML6V2, AIModel__AllMiniLML12V2, @@ -73,10 +70,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, AIServerResponse) @staticmethod - def bincode_deserialize(input: bytes) -> "AIServerResponse": + def bincode_deserialize(input: bytes) -> 'AIServerResponse': v, buffer = bincode.deserialize(input, AIServerResponse) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v @@ -119,21 +116,13 @@ class AIServerResponse__Set(AIServerResponse): @dataclass(frozen=True) class AIServerResponse__Get(AIServerResponse): INDEX = 6 # type: int - value: typing.Sequence[ - typing.Tuple[typing.Optional["StoreInput"], typing.Dict[str, "MetadataValue"]] - ] + value: typing.Sequence[typing.Tuple[typing.Optional["StoreInput"], typing.Dict[str, "MetadataValue"]]] @dataclass(frozen=True) class AIServerResponse__GetSimN(AIServerResponse): INDEX = 7 # type: int - value: typing.Sequence[ - typing.Tuple[ - typing.Optional["StoreInput"], - typing.Dict[str, "MetadataValue"], - "Similarity", - ] - ] + value: typing.Sequence[typing.Tuple[typing.Optional["StoreInput"], typing.Dict[str, "MetadataValue"], "Similarity"]] @dataclass(frozen=True) @@ -147,7 +136,6 @@ class AIServerResponse__CreateIndex(AIServerResponse): INDEX = 9 # type: int value: st.uint64 - AIServerResponse.VARIANTS = [ AIServerResponse__Unit, AIServerResponse__Pong, @@ -170,10 +158,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, AIServerResult) @staticmethod - def bincode_deserialize(input: bytes) -> "AIServerResult": + def bincode_deserialize(input: bytes) -> 'AIServerResult': v, buffer = bincode.deserialize(input, AIServerResult) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v @@ -188,10 +176,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, AIStoreInfo) @staticmethod - def bincode_deserialize(input: bytes) -> "AIStoreInfo": + def bincode_deserialize(input: bytes) -> 'AIStoreInfo': v, buffer = bincode.deserialize(input, AIStoreInfo) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v @@ -202,10 +190,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, AIStoreInputType) @staticmethod - def bincode_deserialize(input: bytes) -> "AIStoreInputType": + def bincode_deserialize(input: bytes) -> 'AIStoreInputType': v, buffer = bincode.deserialize(input, AIStoreInputType) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v @@ -220,7 +208,6 @@ class AIStoreInputType__Image(AIStoreInputType): INDEX = 1 # type: int pass - AIStoreInputType.VARIANTS = [ AIStoreInputType__RawString, AIStoreInputType__Image, @@ -236,10 +223,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, ConnectedClient) @staticmethod - def bincode_deserialize(input: 
bytes) -> "ConnectedClient": + def bincode_deserialize(input: bytes) -> 'ConnectedClient': v, buffer = bincode.deserialize(input, ConnectedClient) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v @@ -250,10 +237,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, MetadataValue) @staticmethod - def bincode_deserialize(input: bytes) -> "MetadataValue": + def bincode_deserialize(input: bytes) -> 'MetadataValue': v, buffer = bincode.deserialize(input, MetadataValue) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v @@ -268,7 +255,6 @@ class MetadataValue__Image(MetadataValue): INDEX = 1 # type: int value: typing.Sequence[st.uint8] - MetadataValue.VARIANTS = [ MetadataValue__RawString, MetadataValue__Image, @@ -282,10 +268,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, Result) @staticmethod - def bincode_deserialize(input: bytes) -> "Result": + def bincode_deserialize(input: bytes) -> 'Result': v, buffer = bincode.deserialize(input, Result) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v @@ -300,7 +286,6 @@ class Result__Err(Result): INDEX = 1 # type: int value: str - Result.VARIANTS = [ Result__Ok, Result__Err, @@ -319,10 +304,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, ServerInfo) @staticmethod - def bincode_deserialize(input: bytes) -> "ServerInfo": + def bincode_deserialize(input: bytes) -> 'ServerInfo': v, buffer = bincode.deserialize(input, ServerInfo) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v @@ -333,10 +318,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, ServerType) @staticmethod - def bincode_deserialize(input: bytes) -> "ServerType": + def bincode_deserialize(input: bytes) -> 'ServerType': v, buffer = bincode.deserialize(input, ServerType) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v @@ -351,7 +336,6 @@ class ServerType__AI(ServerType): INDEX = 1 # type: int pass - ServerType.VARIANTS = [ ServerType__Database, ServerType__AI, @@ -366,10 +350,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, Similarity) @staticmethod - def bincode_deserialize(input: bytes) -> "Similarity": + def bincode_deserialize(input: bytes) -> 'Similarity': v, buffer = bincode.deserialize(input, Similarity) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v @@ -380,10 +364,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, StoreInput) @staticmethod - def bincode_deserialize(input: bytes) -> "StoreInput": + def bincode_deserialize(input: bytes) -> 'StoreInput': v, buffer = bincode.deserialize(input, StoreInput) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v @@ -398,7 +382,6 @@ class StoreInput__Image(StoreInput): INDEX = 1 # type: int value: typing.Sequence[st.uint8] - StoreInput.VARIANTS = [ StoreInput__RawString, StoreInput__Image, 
@@ -414,10 +397,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, StoreUpsert) @staticmethod - def bincode_deserialize(input: bytes) -> "StoreUpsert": + def bincode_deserialize(input: bytes) -> 'StoreUpsert': v, buffer = bincode.deserialize(input, StoreUpsert) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v @@ -430,10 +413,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, SystemTime) @staticmethod - def bincode_deserialize(input: bytes) -> "SystemTime": + def bincode_deserialize(input: bytes) -> 'SystemTime': v, buffer = bincode.deserialize(input, SystemTime) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v @@ -447,8 +430,9 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, Version) @staticmethod - def bincode_deserialize(input: bytes) -> "Version": + def bincode_deserialize(input: bytes) -> 'Version': v, buffer = bincode.deserialize(input, Version) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v + diff --git a/sdk/ahnlich-client-py/ahnlich_client_py/internals/bincode/__init__.py b/sdk/ahnlich-client-py/ahnlich_client_py/internals/bincode/__init__.py index 38cbd7ff..4e5e0837 100644 --- a/sdk/ahnlich-client-py/ahnlich_client_py/internals/bincode/__init__.py +++ b/sdk/ahnlich-client-py/ahnlich_client_py/internals/bincode/__init__.py @@ -1,16 +1,16 @@ # Copyright (c) Facebook, Inc. and its affiliates # SPDX-License-Identifier: MIT OR Apache-2.0 -import collections import dataclasses +import collections import io import struct import typing from copy import copy from typing import get_type_hints -from ahnlich_client_py.internals import serde_binary as sb from ahnlich_client_py.internals import serde_types as st +from ahnlich_client_py.internals import serde_binary as sb # Maximum length in practice for sequences (e.g. in Java). 
MAX_LENGTH = (1 << 31) - 1 diff --git a/sdk/ahnlich-client-py/ahnlich_client_py/internals/db_query.py b/sdk/ahnlich-client-py/ahnlich_client_py/internals/db_query.py index b281f346..b6120b2b 100644 --- a/sdk/ahnlich-client-py/ahnlich_client_py/internals/db_query.py +++ b/sdk/ahnlich-client-py/ahnlich_client_py/internals/db_query.py @@ -1,10 +1,8 @@ # pyre-strict -import typing from dataclasses import dataclass - -from ahnlich_client_py.internals import bincode +import typing from ahnlich_client_py.internals import serde_types as st - +from ahnlich_client_py.internals import bincode class Algorithm: VARIANTS = [] # type: typing.Sequence[typing.Type[Algorithm]] @@ -13,10 +11,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, Algorithm) @staticmethod - def bincode_deserialize(input: bytes) -> "Algorithm": + def bincode_deserialize(input: bytes) -> 'Algorithm': v, buffer = bincode.deserialize(input, Algorithm) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v @@ -43,7 +41,6 @@ class Algorithm__KDTree(Algorithm): INDEX = 3 # type: int pass - Algorithm.VARIANTS = [ Algorithm__EuclideanDistance, Algorithm__DotProductSimilarity, @@ -62,10 +59,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, Array) @staticmethod - def bincode_deserialize(input: bytes) -> "Array": + def bincode_deserialize(input: bytes) -> 'Array': v, buffer = bincode.deserialize(input, Array) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v @@ -76,10 +73,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, MetadataValue) @staticmethod - def bincode_deserialize(input: bytes) -> "MetadataValue": + def bincode_deserialize(input: bytes) -> 'MetadataValue': v, buffer = bincode.deserialize(input, MetadataValue) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v @@ -94,7 +91,6 @@ class MetadataValue__Image(MetadataValue): INDEX = 1 # type: int value: typing.Sequence[st.uint8] - MetadataValue.VARIANTS = [ MetadataValue__RawString, MetadataValue__Image, @@ -108,10 +104,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, NonLinearAlgorithm) @staticmethod - def bincode_deserialize(input: bytes) -> "NonLinearAlgorithm": + def bincode_deserialize(input: bytes) -> 'NonLinearAlgorithm': v, buffer = bincode.deserialize(input, NonLinearAlgorithm) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v @@ -120,7 +116,6 @@ class NonLinearAlgorithm__KDTree(NonLinearAlgorithm): INDEX = 0 # type: int pass - NonLinearAlgorithm.VARIANTS = [ NonLinearAlgorithm__KDTree, ] @@ -133,10 +128,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, Predicate) @staticmethod - def bincode_deserialize(input: bytes) -> "Predicate": + def bincode_deserialize(input: bytes) -> 'Predicate': v, buffer = bincode.deserialize(input, Predicate) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v @@ -167,7 +162,6 @@ class Predicate__NotIn(Predicate): key: str value: typing.Sequence["MetadataValue"] - Predicate.VARIANTS = [ Predicate__Equals, Predicate__NotEquals, @@ 
-183,10 +177,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, PredicateCondition) @staticmethod - def bincode_deserialize(input: bytes) -> "PredicateCondition": + def bincode_deserialize(input: bytes) -> 'PredicateCondition': v, buffer = bincode.deserialize(input, PredicateCondition) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v @@ -207,7 +201,6 @@ class PredicateCondition__Or(PredicateCondition): INDEX = 2 # type: int value: typing.Tuple["PredicateCondition", "PredicateCondition"] - PredicateCondition.VARIANTS = [ PredicateCondition__Value, PredicateCondition__And, @@ -222,10 +215,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, Query) @staticmethod - def bincode_deserialize(input: bytes) -> "Query": + def bincode_deserialize(input: bytes) -> 'Query': v, buffer = bincode.deserialize(input, Query) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v @@ -344,7 +337,6 @@ class Query__Ping(Query): INDEX = 15 # type: int pass - Query.VARIANTS = [ Query__CreateStore, Query__GetKey, @@ -374,8 +366,9 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, ServerQuery) @staticmethod - def bincode_deserialize(input: bytes) -> "ServerQuery": + def bincode_deserialize(input: bytes) -> 'ServerQuery': v, buffer = bincode.deserialize(input, ServerQuery) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v + diff --git a/sdk/ahnlich-client-py/ahnlich_client_py/internals/db_response.py b/sdk/ahnlich-client-py/ahnlich_client_py/internals/db_response.py index d1d0a6c4..acd3baa1 100644 --- a/sdk/ahnlich-client-py/ahnlich_client_py/internals/db_response.py +++ b/sdk/ahnlich-client-py/ahnlich_client_py/internals/db_response.py @@ -1,10 +1,8 @@ # pyre-strict -import typing from dataclasses import dataclass - -from ahnlich_client_py.internals import bincode +import typing from ahnlich_client_py.internals import serde_types as st - +from ahnlich_client_py.internals import bincode @dataclass(frozen=True) class Array: @@ -16,10 +14,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, Array) @staticmethod - def bincode_deserialize(input: bytes) -> "Array": + def bincode_deserialize(input: bytes) -> 'Array': v, buffer = bincode.deserialize(input, Array) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v @@ -32,10 +30,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, ConnectedClient) @staticmethod - def bincode_deserialize(input: bytes) -> "ConnectedClient": + def bincode_deserialize(input: bytes) -> 'ConnectedClient': v, buffer = bincode.deserialize(input, ConnectedClient) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v @@ -46,10 +44,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, MetadataValue) @staticmethod - def bincode_deserialize(input: bytes) -> "MetadataValue": + def bincode_deserialize(input: bytes) -> 'MetadataValue': v, buffer = bincode.deserialize(input, MetadataValue) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise 
st.DeserializationError("Some input bytes were not read"); return v @@ -64,7 +62,6 @@ class MetadataValue__Image(MetadataValue): INDEX = 1 # type: int value: typing.Sequence[st.uint8] - MetadataValue.VARIANTS = [ MetadataValue__RawString, MetadataValue__Image, @@ -78,10 +75,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, Result) @staticmethod - def bincode_deserialize(input: bytes) -> "Result": + def bincode_deserialize(input: bytes) -> 'Result': v, buffer = bincode.deserialize(input, Result) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v @@ -96,7 +93,6 @@ class Result__Err(Result): INDEX = 1 # type: int value: str - Result.VARIANTS = [ Result__Ok, Result__Err, @@ -115,10 +111,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, ServerInfo) @staticmethod - def bincode_deserialize(input: bytes) -> "ServerInfo": + def bincode_deserialize(input: bytes) -> 'ServerInfo': v, buffer = bincode.deserialize(input, ServerInfo) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v @@ -129,10 +125,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, ServerResponse) @staticmethod - def bincode_deserialize(input: bytes) -> "ServerResponse": + def bincode_deserialize(input: bytes) -> 'ServerResponse': v, buffer = bincode.deserialize(input, ServerResponse) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v @@ -181,9 +177,7 @@ class ServerResponse__Get(ServerResponse): @dataclass(frozen=True) class ServerResponse__GetSimN(ServerResponse): INDEX = 7 # type: int - value: typing.Sequence[ - typing.Tuple["Array", typing.Dict[str, "MetadataValue"], "Similarity"] - ] + value: typing.Sequence[typing.Tuple["Array", typing.Dict[str, "MetadataValue"], "Similarity"]] @dataclass(frozen=True) @@ -197,7 +191,6 @@ class ServerResponse__CreateIndex(ServerResponse): INDEX = 9 # type: int value: st.uint64 - ServerResponse.VARIANTS = [ ServerResponse__Unit, ServerResponse__Pong, @@ -220,10 +213,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, ServerResult) @staticmethod - def bincode_deserialize(input: bytes) -> "ServerResult": + def bincode_deserialize(input: bytes) -> 'ServerResult': v, buffer = bincode.deserialize(input, ServerResult) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v @@ -234,10 +227,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, ServerType) @staticmethod - def bincode_deserialize(input: bytes) -> "ServerType": + def bincode_deserialize(input: bytes) -> 'ServerType': v, buffer = bincode.deserialize(input, ServerType) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v @@ -252,7 +245,6 @@ class ServerType__AI(ServerType): INDEX = 1 # type: int pass - ServerType.VARIANTS = [ ServerType__Database, ServerType__AI, @@ -267,10 +259,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, Similarity) @staticmethod - def bincode_deserialize(input: bytes) -> "Similarity": + def bincode_deserialize(input: bytes) -> 'Similarity': v, buffer = bincode.deserialize(input, 
Similarity) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v @@ -284,10 +276,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, StoreInfo) @staticmethod - def bincode_deserialize(input: bytes) -> "StoreInfo": + def bincode_deserialize(input: bytes) -> 'StoreInfo': v, buffer = bincode.deserialize(input, StoreInfo) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v @@ -300,10 +292,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, StoreUpsert) @staticmethod - def bincode_deserialize(input: bytes) -> "StoreUpsert": + def bincode_deserialize(input: bytes) -> 'StoreUpsert': v, buffer = bincode.deserialize(input, StoreUpsert) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v @@ -316,10 +308,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, SystemTime) @staticmethod - def bincode_deserialize(input: bytes) -> "SystemTime": + def bincode_deserialize(input: bytes) -> 'SystemTime': v, buffer = bincode.deserialize(input, SystemTime) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v @@ -333,8 +325,9 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, Version) @staticmethod - def bincode_deserialize(input: bytes) -> "Version": + def bincode_deserialize(input: bytes) -> 'Version': v, buffer = bincode.deserialize(input, Version) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v + diff --git a/sdk/ahnlich-client-py/ahnlich_client_py/internals/serde_binary/__init__.py b/sdk/ahnlich-client-py/ahnlich_client_py/internals/serde_binary/__init__.py index a71b03f5..0730bd23 100644 --- a/sdk/ahnlich-client-py/ahnlich_client_py/internals/serde_binary/__init__.py +++ b/sdk/ahnlich-client-py/ahnlich_client_py/internals/serde_binary/__init__.py @@ -7,8 +7,8 @@ Note: This internal module is currently only meant to share code between the BCS and bincode formats. Internal APIs could change in the future. """ -import collections import dataclasses +import collections import io import typing from typing import get_type_hints diff --git a/sdk/ahnlich-client-py/ahnlich_client_py/internals/serde_types/__init__.py b/sdk/ahnlich-client-py/ahnlich_client_py/internals/serde_types/__init__.py index 1c85909c..6d72f027 100644 --- a/sdk/ahnlich-client-py/ahnlich_client_py/internals/serde_types/__init__.py +++ b/sdk/ahnlich-client-py/ahnlich_client_py/internals/serde_types/__init__.py @@ -1,10 +1,9 @@ # Copyright (c) Facebook, Inc. and its affiliates # SPDX-License-Identifier: MIT OR Apache-2.0 -import typing -from dataclasses import dataclass - import numpy as np +from dataclasses import dataclass +import typing class SerializationError(ValueError): From 40cb9ceaf8457984cb366baa074290c27954fed6 Mon Sep 17 00:00:00 2001 From: HabeebShopeju Date: Fri, 15 Nov 2024 22:59:46 +0000 Subject: [PATCH 03/15] I got the models to work! 
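
In brief: image preprocessing for the ORT provider now runs as a chain of
processors (resize -> center_crop -> imagearray_to_ndarray -> rescale ->
normalize), each configured from the model repo's preprocessor_config.json.
A minimal, self-contained sketch of the chaining idea follows; the types are
simplified stand-ins, not the real API (the actual pipeline passes
ImageArray/ndarray data and returns AIProxyError):

    // Each step consumes the output of the previous one, as in ORTImagePreprocessor.
    trait Processor {
        fn process(&self, data: Vec<f32>) -> Result<Vec<f32>, String>;
    }

    // Stand-in for the Rescale step: multiply every pixel by a scale factor.
    struct Rescale {
        scale: f32,
    }

    impl Processor for Rescale {
        fn process(&self, data: Vec<f32>) -> Result<Vec<f32>, String> {
            Ok(data.into_iter().map(|v| v * self.scale).collect())
        }
    }

    fn run_pipeline(steps: &[Box<dyn Processor>], mut data: Vec<f32>) -> Result<Vec<f32>, String> {
        for step in steps {
            data = step.process(data)?;
        }
        Ok(data)
    }

    fn main() {
        // 1/255 mirrors the default rescale_factor used in the Rescale processor below.
        let steps: Vec<Box<dyn Processor>> = vec![Box::new(Rescale { scale: 1.0 / 255.0 })];
        let out = run_pipeline(&steps, vec![0.0, 255.0]).unwrap();
        assert!((out[1] - 1.0).abs() < 1e-6);
    }

The order matters: pixel-space steps (resize, center_crop) run before the
array conversion, and numeric steps (rescale, normalize) run after it.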
--- ahnlich/ai/src/engine/ai/models.rs | 88 ++++++----- .../ai/src/engine/ai/providers/fastembed.rs | 8 +- ahnlich/ai/src/engine/ai/providers/mod.rs | 7 +- ahnlich/ai/src/engine/ai/providers/ort.rs | 65 ++++---- .../ai/src/engine/ai/providers/ort_helper.rs | 144 +++++------------- .../engine/ai/providers/ort_text_helper.rs | 116 ++++++++++++++ .../ai/providers/processors/center_crop.rs | 122 +++++++++++++++ .../processors/imagearray_to_ndarray.rs | 29 ++++ .../src/engine/ai/providers/processors/mod.rs | 21 +++ .../ai/providers/processors/normalize.rs | 78 ++++++++++ .../ai/providers/processors/preprocessor.rs | 127 +++++++++++++++ .../engine/ai/providers/processors/rescale.rs | 45 ++++++ .../engine/ai/providers/processors/resize.rs | 108 +++++++++++++ ahnlich/ai/src/engine/store.rs | 2 +- ahnlich/ai/src/error.rs | 41 ++++- ahnlich/ai/src/manager/mod.rs | 13 +- ahnlich/ai/src/server/task.rs | 73 ++++----- ahnlich/types/src/ai/preprocess.rs | 2 + ahnlich/utils/src/cli.rs | 2 +- .../ahnlich_client_py/internals/ai_query.py | 81 ++++++---- .../internals/ai_response.py | 86 ++++++----- .../internals/bincode/__init__.py | 4 +- .../ahnlich_client_py/internals/db_query.py | 45 +++--- .../internals/db_response.py | 67 ++++---- .../internals/serde_binary/__init__.py | 2 +- .../internals/serde_types/__init__.py | 5 +- sdk/ahnlich-client-py/demo_embed.py | 98 +++++++++--- 27 files changed, 1103 insertions(+), 376 deletions(-) create mode 100644 ahnlich/ai/src/engine/ai/providers/ort_text_helper.rs create mode 100644 ahnlich/ai/src/engine/ai/providers/processors/center_crop.rs create mode 100644 ahnlich/ai/src/engine/ai/providers/processors/imagearray_to_ndarray.rs create mode 100644 ahnlich/ai/src/engine/ai/providers/processors/mod.rs create mode 100644 ahnlich/ai/src/engine/ai/providers/processors/normalize.rs create mode 100644 ahnlich/ai/src/engine/ai/providers/processors/preprocessor.rs create mode 100644 ahnlich/ai/src/engine/ai/providers/processors/rescale.rs create mode 100644 ahnlich/ai/src/engine/ai/providers/processors/resize.rs diff --git a/ahnlich/ai/src/engine/ai/models.rs b/ahnlich/ai/src/engine/ai/models.rs index f430ef35..093f45ce 100644 --- a/ahnlich/ai/src/engine/ai/models.rs +++ b/ahnlich/ai/src/engine/ai/models.rs @@ -8,7 +8,7 @@ use ahnlich_types::{ ai::{AIModel, AIStoreInputType}, keyval::{StoreInput, StoreKey}, }; -use image::{GenericImageView, ImageReader}; +use image::{DynamicImage, GenericImageView, ImageFormat, ImageReader}; use ndarray::ArrayView; use ndarray::{Array, Ix3}; use nonzero_ext::nonzero; @@ -18,6 +18,8 @@ use std::io::Cursor; use std::num::NonZeroUsize; use std::path::Path; use strum::Display; +use serde::ser::Error as SerError; +use serde::de::Error as DeError; #[derive(Display)] pub enum ModelType { @@ -263,7 +265,8 @@ pub enum ModelInput { #[derive(Debug, Clone)] pub struct ImageArray { array: Array, - bytes: Vec, + image: DynamicImage, + image_format: ImageFormat } impl ImageArray { @@ -272,10 +275,18 @@ impl ImageArray { .with_guessed_format() .map_err(|_| AIProxyError::ImageBytesDecodeError)?; - let img = img_reader + let image_format = &img_reader + .format() + .ok_or(AIProxyError::ImageBytesDecodeError)?; + + let image = img_reader .decode() .map_err(|_| AIProxyError::ImageBytesDecodeError)?; - let (width, height) = img.dimensions(); + + // Always convert to RGB8 format + // https://github.com/Anush008/fastembed-rs/blob/cea92b6c8b877efda762393848d1c449a4eea126/src/image_embedding/utils.rs#L198 + let image: DynamicImage = 
image.to_owned().into_rgb8().into(); + let (width, height) = image.dimensions(); if width == 0 || height == 0 { return Err(AIProxyError::ImageNonzeroDimensionError { @@ -284,12 +295,13 @@ impl ImageArray { }); } - let channels = img.color().channel_count(); - let shape = (height as usize, width as usize, channels as usize); - let array = Array::from_shape_vec(shape, img.into_bytes()) + let channels = &image.color().channel_count(); + let shape = (height as usize, width as usize, *channels as usize); + let array = Array::from_shape_vec(shape, image.clone().into_bytes()) .map_err(|_| AIProxyError::ImageBytesDecodeError)? .mapv(f32::from); - Ok(ImageArray { array, bytes }) + + Ok(ImageArray { array, image, image_format: image_format.to_owned() }) } // Swapping axes from [rows, columns, channels] to [channels, rows, columns] for ONNX @@ -302,42 +314,42 @@ impl ImageArray { self.array.view() } - pub fn get_bytes(&self) -> &Vec { - &self.bytes + pub fn get_bytes(&self) -> Result, AIProxyError> { + let mut buffer = Cursor::new(Vec::new()); + let _ = &self.image + .write_to(&mut buffer, self.image_format) + .map_err(|_| AIProxyError::ImageBytesEncodeError)?; + let bytes = buffer.into_inner(); + Ok(bytes) } - pub fn resize(&self, width: NonZeroUsize, height: NonZeroUsize) -> Result { - let width = usize::from(width); - let height = usize::from(height); - let img_reader = ImageReader::new(Cursor::new(&self.bytes)) - .with_guessed_format() - .map_err(|_| AIProxyError::ImageBytesDecodeError)?; - let img_format = img_reader - .format() - .ok_or(AIProxyError::ImageBytesDecodeError)?; - let original_img = img_reader - .decode() - .map_err(|_| AIProxyError::ImageBytesDecodeError)?; - - let resized_img = original_img.resize_exact( - width as u32, - height as u32, - image::imageops::FilterType::Triangle, + pub fn resize(&self, width: u32, height: u32, filter: Option) -> Result { + let filter_type = filter.unwrap_or(image::imageops::FilterType::CatmullRom); + let resized_img = self.image.resize_exact( + width, + height, + filter_type, ); let channels = resized_img.color().channel_count(); - let shape = (height, width, channels as usize); - - let mut buffer = Cursor::new(Vec::new()); - resized_img - .write_to(&mut buffer, img_format) - .map_err(|_| AIProxyError::ImageResizeError)?; + let shape = (height as usize, width as usize, channels as usize); - let flattened_pixels = resized_img.into_bytes(); + let flattened_pixels = resized_img.clone().into_bytes(); let array = Array::from_shape_vec(shape, flattened_pixels) .map_err(|_| AIProxyError::ImageResizeError)? .mapv(f32::from); - let bytes = buffer.into_inner(); - Ok(ImageArray { array, bytes }) + Ok(ImageArray { array, image: resized_img, image_format: self.image_format }) + } + + pub fn crop(&self, x: u32, y: u32, width: u32, height: u32) -> Result { + let cropped_img = self.image.crop_imm(x, y, width, height); + let channels = cropped_img.color().channel_count(); + let shape = (height as usize, width as usize, channels as usize); + + let flattened_pixels = cropped_img.clone().into_bytes(); + let array = Array::from_shape_vec(shape, flattened_pixels) + .map_err(|_| AIProxyError::ImageCropError)? + .mapv(f32::from); + Ok(ImageArray { array, image: cropped_img, image_format: self.image_format }) } pub fn image_dim(&self) -> (NonZeroUsize, NonZeroUsize) { @@ -354,7 +366,7 @@ impl Serialize for ImageArray { where S: Serializer, { - serializer.serialize_bytes(self.get_bytes()) + serializer.serialize_bytes(&self.get_bytes().map_err(S::Error::custom)?) 
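+        // NOTE: get_bytes() re-encodes the stored DynamicImage in its original
+        // ImageFormat, so serialization performs a fresh encode rather than
+        // copying the input bytes (lossy formats such as JPEG may not
+        // round-trip bit-for-bit).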
} } @@ -364,7 +376,7 @@ impl<'de> Deserialize<'de> for ImageArray { D: Deserializer<'de>, { let bytes: Vec = Deserialize::deserialize(deserializer)?; - ImageArray::try_new(bytes).map_err(serde::de::Error::custom) + ImageArray::try_new(bytes).map_err(D::Error::custom) } } diff --git a/ahnlich/ai/src/engine/ai/providers/fastembed.rs b/ahnlich/ai/src/engine/ai/providers/fastembed.rs index 98e849c8..0ea2431f 100644 --- a/ahnlich/ai/src/engine/ai/providers/fastembed.rs +++ b/ahnlich/ai/src/engine/ai/providers/fastembed.rs @@ -41,17 +41,18 @@ pub struct FastEmbedPreprocessor { // TODO (HAKSOAT): Implement other preprocessors impl TextPreprocessorTrait for FastEmbedProvider { - fn encode_str(&self, text: &str) -> Result, AIProxyError> { + fn encode_str(&self, text: &str) -> Result, AIProxyError> { let preprocessor = self.preprocessor.as_ref() .ok_or(AIProxyError::AIModelNotInitialized)?; let tokens = preprocessor .tokenizer.0.encode_with_special_tokens(text); - Ok(tokens) + Ok(tokens.iter().map(|token| *token as u32).collect()) } - fn decode_tokens(&self, tokens: Vec) -> Result { + fn decode_tokens(&self, tokens: Vec) -> Result { let preprocessor = self.preprocessor.as_ref() .ok_or(AIProxyError::AIModelNotInitialized)?; + let tokens = tokens.iter().map(|token| *token as usize).collect(); let text = preprocessor .tokenizer.0.decode(tokens) .map_err(|_| AIProxyError::ModelTokenizationError)?; @@ -170,7 +171,6 @@ impl ProviderTrait for FastEmbedProvider { Ok(()) } } - // TODO (HAKSOAT): When we add model specific tokenizers, add the get tokenizer call here too. } fn run_inference( diff --git a/ahnlich/ai/src/engine/ai/providers/mod.rs b/ahnlich/ai/src/engine/ai/providers/mod.rs index e00dc2f2..c241601a 100644 --- a/ahnlich/ai/src/engine/ai/providers/mod.rs +++ b/ahnlich/ai/src/engine/ai/providers/mod.rs @@ -1,6 +1,8 @@ pub(crate) mod fastembed; pub(crate) mod ort; +mod ort_text_helper; mod ort_helper; +mod processors; use crate::cli::server::SupportedModels; use crate::engine::ai::models::{InputAction, ModelInput}; @@ -11,6 +13,7 @@ use ahnlich_types::keyval::StoreKey; use std::path::Path; use strum::EnumIter; + #[derive(Debug, EnumIter)] pub enum ModelProviders { FastEmbed(FastEmbedProvider), @@ -30,6 +33,6 @@ pub trait ProviderTrait: std::fmt::Debug + Send + Sync { } pub trait TextPreprocessorTrait { - fn encode_str(&self, text: &str) -> Result, AIProxyError>; - fn decode_tokens(&self, tokens: Vec) -> Result; + fn encode_str(&self, text: &str) -> Result, AIProxyError>; + fn decode_tokens(&self, tokens: Vec) -> Result; } diff --git a/ahnlich/ai/src/engine/ai/providers/ort.rs b/ahnlich/ai/src/engine/ai/providers/ort.rs index 1c27ec70..0f4851e8 100644 --- a/ahnlich/ai/src/engine/ai/providers/ort.rs +++ b/ahnlich/ai/src/engine/ai/providers/ort.rs @@ -1,7 +1,7 @@ use crate::cli::server::SupportedModels; use crate::engine::ai::models::{ImageArray, InputAction, Model, ModelInput}; -use crate::engine::ai::providers::ort_helper::{get_tokenizer_artifacts_hf_hub, normalize, - load_tokenizer_artifacts_hf_hub}; +use crate::engine::ai::providers::ort_text_helper::{get_tokenizer_artifacts_hf_hub, load_tokenizer_artifacts_hf_hub, + normalize}; use crate::engine::ai::providers::{ProviderTrait, TextPreprocessorTrait}; use crate::error::AIProxyError; use fallible_collections::FallibleVec; @@ -9,17 +9,16 @@ use hf_hub::{api::sync::ApiBuilder, Cache}; use itertools::Itertools; use rayon::iter::Either; use ort::{Session, Value}; -use tokenizers::Tokenizer; use rayon::prelude::*; use ahnlich_types::keyval::StoreKey; 
-use ndarray::{Array, Array1, ArrayView, Axis, Ix3}; +use ndarray::{Array, Array1, Axis, Ix4}; use std::convert::TryFrom; use std::default::Default; use std::fmt; use std::path::{Path, PathBuf}; use std::thread::available_parallelism; - +use crate::engine::ai::providers::processors::preprocessor::{ImagePreprocessorFiles, ORTImagePreprocessor, ORTPreprocessor, ORTTextPreprocessor}; #[derive(Default)] pub struct ORTProvider { cache_location: Option, @@ -46,6 +45,7 @@ pub struct ORTImageModel { session: Option, input_params: Vec, output_param: String, + preprocessor_files: ImagePreprocessorFiles } #[derive(Default)] @@ -62,13 +62,6 @@ pub enum ORTModel { Text(ORTTextModel) } -pub enum ORTPreprocessor { - Text(ORTTextPreprocessor), -} - -pub struct ORTTextPreprocessor { - tokenizer: Tokenizer, -} impl TryFrom<&SupportedModels> for ORTModel { type Error = AIProxyError; @@ -116,22 +109,27 @@ impl ORTProvider { } } - pub fn batch_inference_image(&self, mut inputs: Vec) -> Result, AIProxyError> { + pub fn preprocess(&self, data: Vec) -> Result, AIProxyError> { + match &self.preprocessor { + Some(ORTPreprocessor::Image(preprocessor)) => { + let output_data = preprocessor.process(data) + .map_err( + |e| AIProxyError::ModelProviderPreprocessingError( + format!("Preprocessing failed for {:?} with error: {}", + self.supported_models.unwrap().to_string(), e) + ))?; + Ok(output_data) + } + _ => Err(AIProxyError::AIModelNotInitialized) + } + } + + pub fn batch_inference_image(&self, inputs: Vec) -> Result, AIProxyError> { let model = match &self.model { Some(ORTModel::Image(model)) => model, _ => return Err(AIProxyError::AIModelNotSupported { model_name: self.supported_models.unwrap().to_string() }), }; - - let array_views: Vec> = inputs - .par_iter_mut() - .map(|image_arr| { - image_arr.onnx_transform(); - image_arr.view() - }) - .collect(); - - let pixel_values_array = ndarray::stack(ndarray::Axis(0), &array_views) - .map_err(|e| AIProxyError::EmbeddingShapeError(e.to_string()))?; + let pixel_values_array = self.preprocess(inputs)?; match &model.session { Some(session) => { let session_inputs = ort::inputs![ @@ -189,19 +187,16 @@ impl ORTProvider { // Preallocate arrays with the maximum size let mut ids_array = Vec::with_capacity(max_size); let mut mask_array = Vec::with_capacity(max_size); - let mut typeids_array = Vec::with_capacity(max_size); // Not using par_iter because the closure needs to be FnMut encodings.iter().for_each(|encoding| { let ids = encoding.get_ids(); let mask = encoding.get_attention_mask(); - let typeids = encoding.get_type_ids(); // Extend the preallocated arrays with the current encoding // Requires the closure to be FnMut ids_array.extend(ids.iter().map(|x| *x as i64)); mask_array.extend(mask.iter().map(|x| *x as i64)); - typeids_array.extend(typeids.iter().map(|x| *x as i64)); }); // Create CowArrays from vectors @@ -253,7 +248,7 @@ impl ORTProvider { } impl TextPreprocessorTrait for ORTProvider { - fn encode_str(&self, text: &str) -> Result, AIProxyError> { + fn encode_str(&self, text: &str) -> Result, AIProxyError> { match &self.model { Some(ORTModel::Text(model)) => model, _ => return Err(AIProxyError::AIModelNotSupported { model_name: self.supported_models.unwrap().to_string() }), @@ -265,10 +260,11 @@ impl TextPreprocessorTrait for ORTProvider { let tokens = preprocessor.tokenizer.encode(text, true) .map_err(|_| {AIProxyError::ModelTokenizationError})?; - Ok(tokens.get_ids().iter().map(|x| *x as usize).collect()) + let token_ids = tokens.get_ids(); + 
Ok(token_ids.to_vec()) } - fn decode_tokens(&self, tokens: Vec) -> Result { + fn decode_tokens(&self, token_ids: Vec) -> Result { match &self.model { Some(ORTModel::Text(model)) => model, _ => return Err(AIProxyError::AIModelNotSupported { model_name: self.supported_models.unwrap().to_string() }), @@ -278,8 +274,7 @@ impl TextPreprocessorTrait for ORTProvider { return Err(AIProxyError::AIModelNotInitialized); }; - let tokens = tokens.iter().map(|x| *x as u32).collect::>(); - preprocessor.tokenizer.decode(&tokens, true) + preprocessor.tokenizer.decode(&token_ids, true) .map_err(|_| {AIProxyError::ModelTokenizationError}) } } @@ -318,6 +313,7 @@ impl ProviderTrait for ORTProvider { repo_name, input_params: input_param, output_param, + preprocessor_files, .. }) => { let model_repo = api.model(repo_name.clone()); @@ -333,7 +329,12 @@ impl ProviderTrait for ORTProvider { input_params: input_param, output_param, session: Some(session), + ..Default::default() })); + let mut preprocessor = ORTImagePreprocessor::default(); + preprocessor.load(model_repo, preprocessor_files)?; + self.preprocessor = Some(ORTPreprocessor::Image(preprocessor) + ); }, ORTModel::Text(ORTTextModel { weights_file, diff --git a/ahnlich/ai/src/engine/ai/providers/ort_helper.rs b/ahnlich/ai/src/engine/ai/providers/ort_helper.rs index 7a2a0228..e665526d 100644 --- a/ahnlich/ai/src/engine/ai/providers/ort_helper.rs +++ b/ahnlich/ai/src/engine/ai/providers/ort_helper.rs @@ -1,114 +1,10 @@ -// This script was adapted from FastEmbed -// https://github.com/Anush008/fastembed-rs/blob/474d4e62c87666781b580ffc076b8475b693fc34/src/common.rs +use crate::error::AIProxyError; use hf_hub::api::sync::ApiRepo; -use rayon::prelude::*; +use serde_json; +use std::collections::HashMap; +use std::fs::File; use std::io::Read; -use std::{fs::File, path::PathBuf}; -use tokenizers::{AddedToken, PaddingParams, PaddingStrategy, Tokenizer, TruncationParams}; -use crate::error::AIProxyError; - - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct TokenizerFiles { - pub tokenizer_file: Vec, - pub config_file: Vec, - pub special_tokens_map_file: Vec, - pub tokenizer_config_file: Vec, -} - -/// The procedure for loading tokenizer files from the hugging face hub is separated -/// from the main load_tokenizer function (which is expecting bytes, from any source). -pub fn load_tokenizer_artifacts_hf_hub(model_repo: &ApiRepo, max_length: usize) -> Result { - let tokenizer_files: TokenizerFiles = get_tokenizer_artifacts_hf_hub(model_repo)?; - load_tokenizer(tokenizer_files, max_length) -} - -pub fn get_tokenizer_artifacts_hf_hub(model_repo: &ApiRepo) -> Result { - Ok(TokenizerFiles { - tokenizer_file: read_file_to_bytes(&model_repo.get("tokenizer.json") - .map_err(|_| AIProxyError::ModelTokenizerLoadError)?) - .map_err(|_| AIProxyError::ModelTokenizerLoadError)?, - config_file: read_file_to_bytes(&model_repo.get("config.json") - .map_err(|_| AIProxyError::ModelTokenizerLoadError)?) - .map_err(|_| AIProxyError::ModelTokenizerLoadError)?, - special_tokens_map_file: read_file_to_bytes(&model_repo.get("special_tokens_map.json") - .map_err(|_| AIProxyError::ModelTokenizerLoadError)?) - .map_err(|_| AIProxyError::ModelTokenizerLoadError)?, - tokenizer_config_file: read_file_to_bytes(&model_repo.get("tokenizer_config.json") - .map_err(|_| AIProxyError::ModelTokenizerLoadError)?) 
- .map_err(|_| AIProxyError::ModelTokenizerLoadError)?, - }) -} - -/// Function can be called directly from the try_new_from_user_defined function (providing file bytes) -/// -/// Or indirectly from the try_new function via load_tokenizer_hf_hub (converting HF files to bytes) -pub fn load_tokenizer(tokenizer_files: TokenizerFiles, max_length: usize) -> Result { - // Serialise each tokenizer file - let config: serde_json::Value = - serde_json::from_slice(&tokenizer_files.config_file).map_err(|_| AIProxyError::ModelTokenizerLoadError)?; - let special_tokens_map: serde_json::Value = - serde_json::from_slice(&tokenizer_files.special_tokens_map_file).map_err(|_| AIProxyError::ModelTokenizerLoadError)?; - let tokenizer_config: serde_json::Value = - serde_json::from_slice(&tokenizer_files.tokenizer_config_file).map_err(|_| AIProxyError::ModelTokenizerLoadError)?; - let mut tokenizer: tokenizers::Tokenizer = - tokenizers::Tokenizer::from_bytes(tokenizer_files.tokenizer_file).map_err(|_| AIProxyError::ModelTokenizerLoadError)?; - - //For BGEBaseSmall, the model_max_length value is set to 1000000000000000019884624838656. Which fits in a f64 - let model_max_length = tokenizer_config["model_max_length"] - .as_f64() - .ok_or(AIProxyError::ModelTokenizerLoadError)? as f32; - let max_length = max_length.min(model_max_length as usize); - let pad_id = config["pad_token_id"].as_u64().unwrap_or(0) as u32; - let pad_token = tokenizer_config["pad_token"] - .as_str() - .ok_or(AIProxyError::ModelTokenizerLoadError)? - .into(); - - let mut tokenizer = tokenizer - .with_padding(Some(PaddingParams { - // TODO: the user should able to choose the padding strategy - strategy: PaddingStrategy::BatchLongest, - pad_token, - pad_id, - ..Default::default() - })) - .with_truncation(Some(TruncationParams { - max_length, - ..Default::default() - })) - .map_err(|_| AIProxyError::ModelTokenizerLoadError)? - .clone(); - if let serde_json::Value::Object(root_object) = special_tokens_map { - for (_, value) in root_object.iter() { - if value.is_string() { - tokenizer.add_special_tokens(&[AddedToken { - content: value.as_str().unwrap().into(), - special: true, - ..Default::default() - }]); - } else if value.is_object() { - tokenizer.add_special_tokens(&[AddedToken { - content: value["content"].as_str().unwrap().into(), - special: true, - single_word: value["single_word"].as_bool().unwrap(), - lstrip: value["lstrip"].as_bool().unwrap(), - rstrip: value["rstrip"].as_bool().unwrap(), - normalized: value["normalized"].as_bool().unwrap(), - }]); - } - } - } - Ok(tokenizer.into()) -} - -pub fn normalize(v: &[f32]) -> Vec { - let norm = (v.par_iter().map(|val| val * val).sum::()).sqrt(); - let epsilon = 1e-12; - - // We add the super-small epsilon to avoid dividing by zero - v.par_iter().map(|&val| val / (norm + epsilon)).collect() -} +use std::path::PathBuf; /// Public function to read a file to bytes. /// To be used when loading local model files. 
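
The hunk below trims ort_helper.rs down to the byte-reading helper and adds
HFConfigReader, a small memoizing reader for Hugging Face config JSON. A
rough, self-contained sketch of that caching pattern (fetch() is a
hypothetical stand-in for the ApiRepo download plus the serde_json parse):

    use std::collections::HashMap;

    struct ConfigReader {
        cache: HashMap<String, String>,
    }

    impl ConfigReader {
        fn read(&mut self, name: &str) -> Result<String, String> {
            // Serve repeated requests for the same config from memory;
            // only successful reads are cached.
            if let Some(cached) = self.cache.get(name) {
                return Ok(cached.clone());
            }
            let fetched = fetch(name)?;
            self.cache.insert(name.to_string(), fetched.clone());
            Ok(fetched)
        }
    }

    // Hypothetical fallible fetch; the real code hits the model repo and
    // parses the downloaded file with serde_json.
    fn fetch(name: &str) -> Result<String, String> {
        Ok(format!("{{\"source\": \"{}\"}}", name))
    }

    fn main() {
        let mut reader = ConfigReader { cache: HashMap::new() };
        let first = reader.read("preprocessor_config.json").unwrap();
        let second = reader.read("preprocessor_config.json").unwrap(); // cache hit
        assert_eq!(first, second);
    }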
@@ -118,4 +14,34 @@ pub fn read_file_to_bytes(file: &PathBuf) -> Result<Vec<u8>, AIProxyError> {
     let mut buffer = Vec::with_capacity(file_size);
     file.read_to_end(&mut buffer).map_err(|_| AIProxyError::ModelTokenizerLoadError)?;
     Ok(buffer)
+}
+
+pub struct HFConfigReader {
+    model_repo: ApiRepo,
+    cache: HashMap<String, Result<serde_json::Value, AIProxyError>>
+}
+
+impl HFConfigReader {
+    pub fn new(model_repo: ApiRepo) -> Self {
+        Self {
+            model_repo,
+            cache: HashMap::new(),
+        }
+    }
+
+    pub fn read(&mut self, config_name: &str) -> Result<serde_json::Value, AIProxyError> {
+        if let Some(value) = self.cache.get(config_name) {
+            return value.clone();
+        }
+        let file = self.model_repo.get(config_name).map_err(|_| AIProxyError::ModelConfigLoadError{
+            message: format!("failed to fetch {}", config_name),
+        })?;
+        let contents = read_file_to_bytes(&file)?;
+        let value: serde_json::Value = serde_json::from_slice(&contents).map_err(
+            |_| AIProxyError::ModelConfigLoadError{
+                message: format!("failed to parse {}", config_name),
+            })?;
+        self.cache.insert(config_name.to_string(), Ok(value.clone()));
+        Ok(value)
+    }
 }
\ No newline at end of file
diff --git a/ahnlich/ai/src/engine/ai/providers/ort_text_helper.rs b/ahnlich/ai/src/engine/ai/providers/ort_text_helper.rs
new file mode 100644
index 00000000..95b77f37
--- /dev/null
+++ b/ahnlich/ai/src/engine/ai/providers/ort_text_helper.rs
@@ -0,0 +1,116 @@
+// This script was adapted from FastEmbed
+// https://github.com/Anush008/fastembed-rs/blob/474d4e62c87666781b580ffc076b8475b693fc34/src/common.rs
+use hf_hub::api::sync::ApiRepo;
+use rayon::prelude::*;
+use tokenizers::decoders::bpe::BPEDecoder;
+use tokenizers::{AddedToken, PaddingParams, PaddingStrategy, Tokenizer, TruncationParams};
+use crate::engine::ai::providers::ort_helper;
+use crate::error::AIProxyError;
+
+
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub struct TokenizerFiles {
+    pub tokenizer_file: Vec<u8>,
+    pub config_file: Vec<u8>,
+    pub special_tokens_map_file: Vec<u8>,
+    pub tokenizer_config_file: Vec<u8>,
+}
+
+/// The procedure for loading tokenizer files from the Hugging Face hub is separated
+/// from the main load_tokenizer function (which expects bytes, from any source).
+pub fn load_tokenizer_artifacts_hf_hub(model_repo: &ApiRepo, max_length: usize) -> Result<Tokenizer, AIProxyError> {
+    let tokenizer_files: TokenizerFiles = get_tokenizer_artifacts_hf_hub(model_repo)?;
+    load_tokenizer(tokenizer_files, max_length)
+}
+
+pub fn get_tokenizer_artifacts_hf_hub(model_repo: &ApiRepo) -> Result<TokenizerFiles, AIProxyError> {
+    Ok(TokenizerFiles {
+        tokenizer_file: ort_helper::read_file_to_bytes(&model_repo.get("tokenizer.json")
+            .map_err(|_| AIProxyError::ModelTokenizerLoadError)?)
+            .map_err(|_| AIProxyError::ModelTokenizerLoadError)?,
+        config_file: ort_helper::read_file_to_bytes(&model_repo.get("config.json")
+            .map_err(|_| AIProxyError::ModelTokenizerLoadError)?)
+            .map_err(|_| AIProxyError::ModelTokenizerLoadError)?,
+        special_tokens_map_file: ort_helper::read_file_to_bytes(&model_repo.get("special_tokens_map.json")
+            .map_err(|_| AIProxyError::ModelTokenizerLoadError)?)
+            .map_err(|_| AIProxyError::ModelTokenizerLoadError)?,
+        tokenizer_config_file: ort_helper::read_file_to_bytes(&model_repo.get("tokenizer_config.json")
+            .map_err(|_| AIProxyError::ModelTokenizerLoadError)?)
+            .map_err(|_| AIProxyError::ModelTokenizerLoadError)?,
+    })
+}
+
+/// Can be called directly from the try_new_from_user_defined function (providing file bytes),
+/// or indirectly from the try_new function via load_tokenizer_hf_hub (converting HF files to bytes).
+pub fn load_tokenizer(tokenizer_files: TokenizerFiles, max_length: usize) -> Result<Tokenizer, AIProxyError> {
+    // Deserialise each tokenizer file
+    let config: serde_json::Value =
+        serde_json::from_slice(&tokenizer_files.config_file).map_err(|_| AIProxyError::ModelTokenizerLoadError)?;
+    let special_tokens_map: serde_json::Value =
+        serde_json::from_slice(&tokenizer_files.special_tokens_map_file).map_err(|_| AIProxyError::ModelTokenizerLoadError)?;
+    let tokenizer_config: serde_json::Value =
+        serde_json::from_slice(&tokenizer_files.tokenizer_config_file).map_err(|_| AIProxyError::ModelTokenizerLoadError)?;
+    let mut tokenizer: tokenizers::Tokenizer =
+        tokenizers::Tokenizer::from_bytes(tokenizer_files.tokenizer_file).map_err(|_| AIProxyError::ModelTokenizerLoadError)?;
+
+    // For BGEBaseSmall, model_max_length is set to 1000000000000000019884624838656, which fits in an f64.
+    let model_max_length = tokenizer_config["model_max_length"]
+        .as_f64()
+        .ok_or(AIProxyError::ModelTokenizerLoadError)? as f32;
+    let max_length = max_length.min(model_max_length as usize);
+    let pad_id = config["pad_token_id"].as_u64().unwrap_or(0) as u32;
+    let pad_token = tokenizer_config["pad_token"]
+        .as_str()
+        .ok_or(AIProxyError::ModelTokenizerLoadError)?
+        .into();
+
+    let mut tokenizer = tokenizer
+        .with_padding(Some(PaddingParams {
+            // TODO: the user should be able to choose the padding strategy
+            strategy: PaddingStrategy::BatchLongest,
+            pad_token,
+            pad_id,
+            ..Default::default()
+        }))
+        .with_truncation(Some(TruncationParams {
+            max_length,
+            ..Default::default()
+        }))
+        .map_err(|_| AIProxyError::ModelTokenizerLoadError)?
+        .clone();
+    if let serde_json::Value::Object(root_object) = special_tokens_map {
+        for (_, value) in root_object.iter() {
+            if value.is_string() {
+                tokenizer.add_special_tokens(&[AddedToken {
+                    content: value.as_str().unwrap().into(),
+                    special: true,
+                    ..Default::default()
+                }]);
+            } else if value.is_object() {
+                tokenizer.add_special_tokens(&[AddedToken {
+                    content: value["content"].as_str().unwrap().into(),
+                    special: true,
+                    single_word: value["single_word"].as_bool().unwrap(),
+                    lstrip: value["lstrip"].as_bool().unwrap(),
+                    rstrip: value["rstrip"].as_bool().unwrap(),
+                    normalized: value["normalized"].as_bool().unwrap(),
+                }]);
+            }
+        }
+    }
+
+    let decoder = BPEDecoder::new("</w>".to_string());
+    tokenizer.with_decoder(Some(decoder));
+
+    Ok(tokenizer.into())
+}
+
+pub fn normalize(v: &[f32]) -> Vec<f32> {
+    let norm = (v.par_iter().map(|val| val * val).sum::<f32>()).sqrt();
+    let epsilon = 1e-12;
+
+    // We add the super-small epsilon to avoid dividing by zero
+    v.par_iter().map(|&val| val / (norm + epsilon)).collect()
+}
+
diff --git a/ahnlich/ai/src/engine/ai/providers/processors/center_crop.rs b/ahnlich/ai/src/engine/ai/providers/processors/center_crop.rs
new file mode 100644
index 00000000..968787a2
--- /dev/null
+++ b/ahnlich/ai/src/engine/ai/providers/processors/center_crop.rs
@@ -0,0 +1,122 @@
+use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
+use crate::engine::ai::models::ImageArray;
+use crate::engine::ai::providers::processors::{CONV_NEXT_FEATURE_EXTRACTOR_CENTER_CROP_THRESHOLD, Processor, ProcessorData};
+use crate::error::AIProxyError;
+
+pub struct CenterCrop {
+    crop_size: (u32, u32), // (width, height)
+    process: bool
+}
+
+impl TryFrom<&serde_json::Value> for CenterCrop {
+    type Error = AIProxyError;
+
+    fn try_from(config: &serde_json::Value) -> Result<Self, Self::Error> {
+        if !config["do_center_crop"].as_bool().unwrap_or(false) {
+            return Ok(
+                Self {
+                    crop_size: (0, 0),
+                    process: false
+                }
+            );
+        }
+
+        let image_processor_type = config["image_processor_type"].as_str()
+            .unwrap_or("CLIPImageProcessor");
+
+        match image_processor_type {
+            "CLIPImageProcessor" => {
+                let crop_size = &config["crop_size"];
+                let has_crop_size = crop_size.is_object() || crop_size.is_u64();
+                if !has_crop_size {
+                    return Err(AIProxyError::ModelConfigLoadError {
+                        message:
+                            "The key 'crop_size' is missing from the configuration or has the wrong type; \
+                            it should be an integer or an object containing 'height' and 'width' mappings.".to_string(),
+                    });
+                }
+                let (width, height);
+                if crop_size.is_object() {
+                    height = crop_size["height"].as_u64().ok_or_else(|| AIProxyError::ModelConfigLoadError {
+                        message: "The key 'height' is missing from the ['crop_size'] section of \
+                        the configuration or has the wrong type; it should be an integer".to_string(),
+                    })? as u32;
+                    width = crop_size["width"].as_u64().ok_or_else(|| AIProxyError::ModelConfigLoadError {
+                        message: "The key 'width' is missing from the ['crop_size'] section of \
+                        the configuration or has the wrong type; it should be an integer".to_string(),
+                    })? as u32;
+                } else {
+                    let size = crop_size.as_u64().expect("It will always be an integer here.") as u32;
+                    width = size;
+                    height = size;
+                }
+
+                Ok(Self {
+                    crop_size: (width, height),
+                    process: true
+                })
+            },
+            "ConvNextFeatureExtractor" => {
+                let size = &config["size"];
+                if !size.is_object() {
+                    return Err(AIProxyError::ModelConfigLoadError {
+                        message: "The key 'size' is missing from the configuration or has the wrong type; it should be an object containing a 'shortest_edge' mapping.".to_string(),
+                    });
+                }
+                let shortest_edge = size["shortest_edge"].as_u64()
+                    .ok_or_else(|| AIProxyError::ModelConfigLoadError {
+                        message: "The key 'shortest_edge' is missing from the ['size'] section of \
+                        the configuration or has the wrong type; it should be an integer".to_string(),
+                    })? as u32;
+                Ok(Self {
+                    crop_size: (shortest_edge, shortest_edge),
+                    process: shortest_edge < CONV_NEXT_FEATURE_EXTRACTOR_CENTER_CROP_THRESHOLD
+                })
+            },
+            _ => Err(AIProxyError::ModelConfigLoadError {
+                message: format!("The key 'image_processor_type' in the configuration has the wrong value: {}; \
+                it should be either 'CLIPImageProcessor' or 'ConvNextFeatureExtractor'.", image_processor_type),
+            })
+        }
+    }
+}
+
+impl Processor for CenterCrop {
+    fn process(&self, data: ProcessorData) -> Result<ProcessorData, AIProxyError> {
+        match data {
+            ProcessorData::ImageArray(image_array) => {
+                let processed = image_array.par_iter().map(|image| {
+                    if !self.process {
+                        return Ok(image.clone());
+                    }
+
+                    let (width, height) = image.image_dim();
+                    let width = width.get() as u32;
+                    let height = height.get() as u32;
+                    let (crop_width, crop_height) = self.crop_size;
+                    if crop_width == width && crop_height == height {
+                        let image = image.to_owned();
+                        Ok(image)
+                    } else if crop_width <= width && crop_height <= height {
+                        // `&&` (not `||`): a crop larger than the image in either dimension
+                        // must fall through to resize, or the subtractions below underflow.
+                        let x = (width - crop_width) / 2;
+                        let y = (height - crop_height) / 2;
+                        let image = image.crop(x, y, crop_width, crop_height)?;
+                        Ok(image)
+                    } else {
+                        // The Fastembed-rs implementation pads the image with zeros, but that does not make
+                        // sense to me (HAKSOAT), just as it does not make sense to "crop" to a bigger size.
+                        // This is why I am going with resize. I expect these cases to be minor anyway,
+                        // because Resize will often be called before CenterCrop.
+                        let image = image.resize(crop_width, crop_height, None)?;
+                        Ok(image)
+                    }
+                })
+                .collect::<Result<Vec<ImageArray>, AIProxyError>>();
+                Ok(ProcessorData::ImageArray(processed?))
+            },
+            _ => Err(AIProxyError::CenterCropError {
+                message: "CenterCrop process failed.
Expected ImageArray, got NdArray3C".to_string(), + }) + } + } +} \ No newline at end of file diff --git a/ahnlich/ai/src/engine/ai/providers/processors/imagearray_to_ndarray.rs b/ahnlich/ai/src/engine/ai/providers/processors/imagearray_to_ndarray.rs new file mode 100644 index 00000000..fa53b1bb --- /dev/null +++ b/ahnlich/ai/src/engine/ai/providers/processors/imagearray_to_ndarray.rs @@ -0,0 +1,29 @@ +use ndarray::{ArrayView, Ix3}; +use rayon::iter::{IntoParallelRefMutIterator, ParallelIterator}; +use crate::engine::ai::providers::processors::{Processor, ProcessorData}; +use crate::error::AIProxyError; + +pub struct ImageArrayToNdArray; + +impl Processor for ImageArrayToNdArray { + fn process(&self, data: ProcessorData) -> Result { + match data { + ProcessorData::ImageArray(mut arrays) => { + let array_views: Vec> = arrays + .par_iter_mut() + .map(|image_arr| { + image_arr.onnx_transform(); + image_arr.view() + }) + .collect(); + + let pixel_values_array = ndarray::stack(ndarray::Axis(0), &array_views) + .map_err(|e| AIProxyError::EmbeddingShapeError(e.to_string()))?; + Ok(ProcessorData::NdArray3C(pixel_values_array)) + } + _ => Err(AIProxyError::ImageArrayToNdArrayError { + message: "ImageArrayToNdArray failed. Expected ImageArray, got NdArray3C".to_string(), + }), + } + } +} diff --git a/ahnlich/ai/src/engine/ai/providers/processors/mod.rs b/ahnlich/ai/src/engine/ai/providers/processors/mod.rs new file mode 100644 index 00000000..4f7518f6 --- /dev/null +++ b/ahnlich/ai/src/engine/ai/providers/processors/mod.rs @@ -0,0 +1,21 @@ +use crate::engine::ai::models::ImageArray; +use crate::error::AIProxyError; +use ndarray::{Array, Ix4}; + +pub mod normalize; +pub mod resize; +pub mod imagearray_to_ndarray; +pub mod center_crop; +pub mod rescale; +pub mod preprocessor; + +pub const CONV_NEXT_FEATURE_EXTRACTOR_CENTER_CROP_THRESHOLD: u32 = 384; + +pub trait Processor: Send + Sync { + fn process(&self, data: ProcessorData) -> Result; +} + +pub enum ProcessorData { + ImageArray(Vec), + NdArray3C(Array) +} diff --git a/ahnlich/ai/src/engine/ai/providers/processors/normalize.rs b/ahnlich/ai/src/engine/ai/providers/processors/normalize.rs new file mode 100644 index 00000000..adea122f --- /dev/null +++ b/ahnlich/ai/src/engine/ai/providers/processors/normalize.rs @@ -0,0 +1,78 @@ +use crate::error::AIProxyError; +use crate::engine::ai::providers::processors::{Processor, ProcessorData}; +use ndarray::Array; +use std::ops::{Div, Sub}; + +pub struct Normalize { + mean: Vec, + std: Vec, + process: bool +} + +impl TryFrom<&serde_json::Value> for Normalize { + type Error = AIProxyError; + + fn try_from(config: &serde_json::Value) -> Result { + if !config["do_normalize"].as_bool().unwrap_or(false) { + return Ok( + Self { + mean: vec![], + std: vec![], + process: false + } + ); + } + + fn get_array(value: &serde_json::Value, key: &str) -> Result, AIProxyError> { + let field = value.get(key) + .ok_or_else(|| AIProxyError::ModelConfigLoadError { + message: format!("The key '{}' is missing from the configuration.", key), + })?; + serde_json::from_value(field.to_owned()).map_err(|_| AIProxyError::ModelConfigLoadError { + message: format!("The key '{}' in the configuration must be an array of floats.", key), + }) + } + + let mean = get_array(config, "image_mean")?; + let std = get_array(config, "image_std")?; + Ok(Self { mean, std, process: true }) + } +} + +impl Processor for Normalize { + fn process(&self, data: ProcessorData) -> Result { + if !self.process { + return Ok(data); + } + + match data { + 
ProcessorData::NdArray3C(array) => { + let mean = Array::from_vec(self.mean.clone()) + .into_shape_with_order((3, 1, 1)) + .unwrap(); + let std = Array::from_vec(self.std.clone()) + .into_shape_with_order((3, 1, 1)) + .unwrap(); + + let shape = array.shape().to_vec(); + match shape.as_slice() { + [b, c, h, w] => { + let mean_broadcast = mean.broadcast((*b, *c, *h, *w)).expect("Broadcast will always succeed."); + let std_broadcast = std.broadcast((*b, *c, *h, *w)).expect("Broadcast will always succeed."); + let array_normalized = array + .sub(mean_broadcast) + .div(std_broadcast); + Ok(ProcessorData::NdArray3C(array_normalized)) + } + _ => Err(AIProxyError::ImageNormalizationError { + message: format!("Image normalization failed due to invalid shape for image array; \ + expected 4 dimensions, got {} dimensions.", shape.len()), + }), + } + } + _ => Err(AIProxyError::ImageNormalizationError { + message: "Expected NdArray3C, got ImageArray".to_string(), + }), + } + } +} \ No newline at end of file diff --git a/ahnlich/ai/src/engine/ai/providers/processors/preprocessor.rs b/ahnlich/ai/src/engine/ai/providers/processors/preprocessor.rs new file mode 100644 index 00000000..da3d83b1 --- /dev/null +++ b/ahnlich/ai/src/engine/ai/providers/processors/preprocessor.rs @@ -0,0 +1,127 @@ +use std::iter; +use hf_hub::api::sync::ApiRepo; +use ndarray::{Array, Ix4}; +use tokenizers::Tokenizer; +use crate::engine::ai::models::ImageArray; +use crate::engine::ai::providers::ort_helper::HFConfigReader; +use crate::engine::ai::providers::processors::center_crop::CenterCrop; +use crate::engine::ai::providers::processors::imagearray_to_ndarray::ImageArrayToNdArray; +use crate::engine::ai::providers::processors::normalize::Normalize; +use crate::engine::ai::providers::processors::{Processor, ProcessorData}; +use crate::engine::ai::providers::processors::rescale::Rescale; +use crate::engine::ai::providers::processors::resize::Resize; +use crate::error::AIProxyError; + +pub struct ImagePreprocessorFiles { + resize: Option, + normalize: Option, + rescale: Option, + center_crop: Option, +} + +impl ImagePreprocessorFiles { + pub fn iter(&self) -> impl Iterator { + iter::empty() + .chain(self.resize.as_ref().map( + |n| ("resize", n.as_str()))) + .chain(self.normalize.as_ref().map( + |n| ("normalize", n.as_str()))) + .chain(self.rescale.as_ref().map( + |n| ("rescale", n.as_str()))) + .chain(self.center_crop.as_ref().map( + |n| ("center_crop", n.as_str()))) + } +} + +impl Default for ImagePreprocessorFiles { + fn default() -> Self { + Self { + normalize: Some("preprocessor_config.json".to_string()), + resize: Some("preprocessor_config.json".to_string()), + rescale: Some("preprocessor_config.json".to_string()), + center_crop: Some("preprocessor_config.json".to_string()), + } + } +} + +#[derive(Default)] +pub struct ORTImagePreprocessor { + imagearray_to_ndarray: Option>, + normalize: Option>, + resize: Option>, + rescale: Option>, + center_crop: Option>, +} + +impl ORTImagePreprocessor { + pub fn iter(&self) -> impl Iterator)> { + iter::empty() + .chain(self.resize.as_ref().map( + |f| ("resize", f))) + .chain(self.center_crop.as_ref().map( + |f| ("center_crop", f))) + .chain(self.imagearray_to_ndarray.as_ref().map( + |f| ("imagearray_to_ndarray", f))) + .chain(self.rescale.as_ref().map( + |f| ("rescale", f))) + .chain(self.normalize.as_ref().map( + |f| ("normalize", f))) + } + + pub fn load(&mut self, model_repo: ApiRepo, processor_files: ImagePreprocessorFiles) -> Result<(), AIProxyError> { + let mut type_and_configs: 
Vec<(&str, Option)> = vec![ + ("imagearray_to_ndarray", None) + ]; + + let mut config_reader = HFConfigReader::new(model_repo); + for data in processor_files.iter() { + type_and_configs.push((data.0, Some(config_reader.read(data.1)?))); + } + for (processor_type, config) in type_and_configs { + match processor_type { + "imagearray_to_ndarray" => { + self.imagearray_to_ndarray = Some(Box::new(ImageArrayToNdArray)); + } + "resize" => { + self.resize = Some(Box::new(Resize::try_from(&config.expect("Config exists"))?)); + } + "normalize" => { + self.normalize = Some(Box::new(Normalize::try_from(&config.expect("Config exists"))?)); + } + "rescale" => { + self.rescale = Some(Box::new(Rescale::try_from(&config.expect("Config exists"))?)); + } + "center_crop" => { + self.center_crop = Some(Box::new(CenterCrop::try_from(&config.expect("Config exists"))?)); + } + _ => return Err(AIProxyError::ModelProviderPreprocessingError( + format!("The {} operation not found in ImagePreprocessor.", processor_type) + )) + } + } + Ok(()) + } + + pub fn process(&self, data: Vec) -> Result, AIProxyError> { + let mut data = ProcessorData::ImageArray(data); + for (_, processor) in self.iter() { + data = processor.process(data)?; + } + match data { + ProcessorData::NdArray3C(array) => Ok(array), + _ => Err(AIProxyError::ModelProviderPreprocessingError( + "Expected NdArray after processing".to_string() + )) + } + } +} + +pub enum ORTPreprocessor { + Image(ORTImagePreprocessor), + Text(ORTTextPreprocessor), +} + +pub struct ORTTextPreprocessor { + pub tokenizer: Tokenizer, +} \ No newline at end of file diff --git a/ahnlich/ai/src/engine/ai/providers/processors/rescale.rs b/ahnlich/ai/src/engine/ai/providers/processors/rescale.rs new file mode 100644 index 00000000..5b3ed2fb --- /dev/null +++ b/ahnlich/ai/src/engine/ai/providers/processors/rescale.rs @@ -0,0 +1,45 @@ +use crate::engine::ai::providers::processors::{Processor, ProcessorData}; +use crate::error::AIProxyError; + +pub struct Rescale { + scale: f32, + process: bool +} + +impl TryFrom<&serde_json::Value> for Rescale { + type Error = AIProxyError; + + fn try_from(config: &serde_json::Value) -> Result { + if !config["do_rescale"].as_bool().unwrap_or(true) { + return Ok( + Self { + scale: 0f32, + process: false + } + ); + } + + let default_scale = 1.0/255.0; + let scale = config["rescale_factor"].as_f64().unwrap_or(default_scale) as f32; + Ok(Self { scale, process: true }) + } +} + +impl Processor for Rescale { + fn process(&self, data: ProcessorData) -> Result { + if !self.process { + return Ok(data); + } + + match data { + ProcessorData::NdArray3C(array) => { + let mut array = array; + array *= self.scale; + Ok(ProcessorData::NdArray3C(array)) + }, + _ => Err(AIProxyError::RescaleError { + message: "Rescale process failed. 
Expected NdArray3C, got ImageArray".to_string(), + }), + } + } +} \ No newline at end of file diff --git a/ahnlich/ai/src/engine/ai/providers/processors/resize.rs b/ahnlich/ai/src/engine/ai/providers/processors/resize.rs new file mode 100644 index 00000000..a670db38 --- /dev/null +++ b/ahnlich/ai/src/engine/ai/providers/processors/resize.rs @@ -0,0 +1,108 @@ +use image::imageops::FilterType; +use rayon::iter::{IntoParallelRefMutIterator, ParallelIterator}; +use crate::engine::ai::models::ImageArray; +use crate::engine::ai::providers::processors::{CONV_NEXT_FEATURE_EXTRACTOR_CENTER_CROP_THRESHOLD, Processor, ProcessorData}; +use crate::error::AIProxyError; + +pub struct Resize { + size: (u32, u32), // (width, height) + resample: FilterType, + process: bool +} + +impl TryFrom<&serde_json::Value> for Resize { + type Error = AIProxyError; + + fn try_from(config: &serde_json::Value) -> Result { + // let config = SafeValue::new(config.to_owned()); + if !config["do_resize"].as_bool().unwrap_or(false) { + return Ok( + Self { + size: (0, 0), + resample: FilterType::CatmullRom, + process: false + } + ); + } + + let image_processor_type = config["image_processor_type"].as_str() + .unwrap_or("CLIPImageProcessor"); + + let size = &config["size"]; + if !size.is_object() { + return Err(AIProxyError::ModelConfigLoadError { + message: "The key 'size' is missing from the configuration or has the wrong type; \ + it should be an object containing a 'shortest_edge' mapping or a 'width' and 'height' mapping.".to_string(), + }); + } + + let (width, height): (u32, u32); + + let shortest_edge = &size["shortest_edge"]; + let size_width = &size["width"]; + let size_height = &size["height"]; + let has_value = shortest_edge.is_u64() || + (size_width.is_u64() && size_height.is_u64() + && image_processor_type == "CLIPImageProcessor"); + if !has_value { + return Err(AIProxyError::ModelConfigLoadError { + message: "The ['size'] section of the configuration must contain either a \ + 'shortest_edge' mapping or 'width' and 'height' mappings (when \ + 'image_processor_type' is 'CLIPImageProcessor'); they should be \ + integers.".to_string(), + }); + } + + if shortest_edge.is_u64() { + width = shortest_edge.as_u64().expect("It will always be an integer here.") as u32; + height = width; + } else { + width = size_width.as_u64().expect("It will always be an integer here.") as u32; + height = size_height.as_u64().expect("It will always be an integer here.") as u32; + } + + match image_processor_type { + "CLIPImageProcessor" => { + Ok(Self { size: (width, height), resample: FilterType::CatmullRom, process: true }) + }, + "ConvNextFeatureExtractor" => { + if width >= CONV_NEXT_FEATURE_EXTRACTOR_CENTER_CROP_THRESHOLD { + Ok(Self { size: (width, height), resample: FilterType::CatmullRom, process: true + }) + } else { + let default_crop_pct = 0.875; + let crop_pct = config["crop_pct"].as_f64().unwrap_or(default_crop_pct) as f32; + let upsampled_edge = (width as f32 / crop_pct) as u32; + Ok(Self { size: (upsampled_edge, upsampled_edge), resample: FilterType::CatmullRom, + process: true }) + } + }, + _ => Err(AIProxyError::ModelConfigLoadError { + message: format!("Resize init failed. 
+ +impl Processor for Resize { + fn process(&self, data: ProcessorData) -> Result<ProcessorData, AIProxyError> { + match data { + ProcessorData::ImageArray(mut arrays) => { + let processed = arrays.par_iter_mut() + .map(|image| { + if !self.process { + return Ok(image.clone()); + } + + let image = image.resize(self.size.0, self.size.1, Some(self.resample))?; + Ok(image) + }) + .collect::<Result<Vec<ImageArray>, AIProxyError>>(); + Ok(ProcessorData::ImageArray(processed?)) + } + _ => Err(AIProxyError::ImageArrayToNdArrayError { + message: "Resize failed. Expected ImageArray, got NdArray3C".to_string(), + }), + } + } +} \ No newline at end of file diff --git a/ahnlich/ai/src/engine/store.rs b/ahnlich/ai/src/engine/store.rs index d2f02215..745aff5d 100644 --- a/ahnlich/ai/src/engine/store.rs +++ b/ahnlich/ai/src/engine/store.rs @@ -283,7 +283,7 @@ impl AIStoreHandler { let store = self.get(store_name)?; let mut store_keys = model_manager .handle_request( - &store.index_model, + &store.query_model, vec![store_input], preprocess_action, InputAction::Query, diff --git a/ahnlich/ai/src/error.rs b/ahnlich/ai/src/error.rs index 97a6479a..7b77b215 100644 --- a/ahnlich/ai/src/error.rs +++ b/ahnlich/ai/src/error.rs @@ -8,7 +8,7 @@ use tokio::sync::oneshot::error::RecvError; use crate::engine::ai::models::InputAction; -#[derive(Error, Debug, PartialEq, Eq)] +#[derive(Error, Clone, Debug, PartialEq, Eq)] pub enum AIProxyError { #[error("Store {0} not found")] StoreNotFound(StoreName), @@ -74,6 +74,32 @@ pub enum AIProxyError { model_name: String }, + #[error("Invalid operation [{operation}] on model [{model_name}]")] + AIModelInvalidOperation { + operation: String, + model_name: String + }, + + #[error("Normalization error: [{message}]")] + ImageNormalizationError { + message: String + }, + + #[error("ImageArray to NdArray conversion error: [{message}]")] + ImageArrayToNdArrayError { + message: String + }, + + #[error("Rescale error: [{message}]")] + RescaleError { + message: String + }, + + #[error("Center crop error: [{message}]")] + CenterCropError { + message: String + }, + // TODO: Add SendError from mpsc::Sender into this variant #[error("Error sending request to model thread")] AIModelThreadSendError, @@ -95,6 +121,9 @@ pub enum AIProxyError { #[error("Bytes could not be successfully decoded into an image.")] ImageBytesDecodeError, + #[error("Image could not be successfully encoded into bytes.")] + ImageBytesEncodeError, + #[error( "Image can't have zero value in any dimension. 
Found height: {height}, width: {width}" )] @@ -103,6 +132,9 @@ pub enum AIProxyError { #[error("Image could not be resized.")] ImageResizeError, + #[error("Image could not be cropped.")] + ImageCropError, + #[error("Model provider failed on preprocessing the input {0}")] ModelProviderPreprocessingError(String), @@ -122,7 +154,12 @@ pub enum AIProxyError { DelKeyError, #[error("Tokenizer for model failed on loading.")] - ModelTokenizerLoadError + ModelTokenizerLoadError, + + #[error("Unable to load config: [{message}].")] + ModelConfigLoadError{ + message: String + } } impl From for AIProxyError { diff --git a/ahnlich/ai/src/manager/mod.rs b/ahnlich/ai/src/manager/mod.rs index 603ac62f..059c0a31 100644 --- a/ahnlich/ai/src/manager/mod.rs +++ b/ahnlich/ai/src/manager/mod.rs @@ -111,12 +111,12 @@ impl ModelThread { input: String, string_action: StringAction, ) -> Result { - let max_token_size = self.model.max_input_token().unwrap_or_else(|| { - panic!( - "`max_input_token()` is not supported for model: {:?}", - self.model.model_name() - ) - }); + let max_token_size = self.model.max_input_token().ok_or_else(|| { + AIProxyError::AIModelInvalidOperation { + model_name: self.model.model_name(), + operation: "[max_input_token] function".to_string() + } + })?; if self.model.input_type() != AIStoreInputType::RawString { return Err(AIProxyError::TokenTruncationNotSupported); @@ -175,7 +175,6 @@ impl ModelThread { expected_dimensions: (expected_width.into(), expected_height.into()), }); } else { - let input = input.resize(expected_width, expected_height)?; return Ok(input); } } diff --git a/ahnlich/ai/src/server/task.rs b/ahnlich/ai/src/server/task.rs index 65a99021..b578dc9a 100644 --- a/ahnlich/ai/src/server/task.rs +++ b/ahnlich/ai/src/server/task.rs @@ -338,46 +338,47 @@ impl AhnlichProtocol for AIProxyTask { preprocess, ) .await; - if let Ok(store_key) = repr { - match self - .db_client - .get_sim_n( - store, - store_key, - closest_n, - algorithm, - condition, - parent_id.clone(), - ) - .await - { - Ok(res) => { - if let ServerResponse::GetSimN(response) = res { - let (store_key_input, similarities): (Vec<_>, Vec<_>) = - response - .into_par_iter() - .map(|(a, b, c)| ((a, b), c)) - .unzip(); - Ok(AIServerResponse::GetSimN( - self.store_handler - .store_key_val_to_store_input_val(store_key_input) - .into_par_iter() - .zip(similarities.into_par_iter()) - .map(|((a, b), c)| (a, b, c)) - .collect(), - )) - } else { - Err(AIProxyError::UnexpectedDBResponse(format!("{:?}", res)) - .to_string()) + match repr { + Ok(store_key) => { + match self + .db_client + .get_sim_n( + store, + store_key, + closest_n, + algorithm, + condition, + parent_id.clone(), + ) + .await + { + Ok(res) => { + if let ServerResponse::GetSimN(response) = res { + let (store_key_input, similarities): (Vec<_>, Vec<_>) = + response + .into_par_iter() + .map(|(a, b, c)| ((a, b), c)) + .unzip(); + Ok(AIServerResponse::GetSimN( + self.store_handler + .store_key_val_to_store_input_val(store_key_input) + .into_par_iter() + .zip(similarities.into_par_iter()) + .map(|((a, b), c)| (a, b, c)) + .collect(), + )) + } else { + Err(AIProxyError::UnexpectedDBResponse(format!("{:?}", res)) + .to_string()) + } } + Err(err) => Err(format!("{err}")), } - Err(err) => Err(format!("{err}")), } - } else { - Err( - AIProxyError::StandardError("Failed to get store".to_string()) + Err(err) => Err( + AIProxyError::StandardError(err.to_string()) .to_string(), - ) + ), } } AIQuery::PurgeStores => { diff --git a/ahnlich/types/src/ai/preprocess.rs 
b/ahnlich/types/src/ai/preprocess.rs index 7c0403fc..248e8c3a 100644 --- a/ahnlich/types/src/ai/preprocess.rs +++ b/ahnlich/types/src/ai/preprocess.rs @@ -8,6 +8,7 @@ use std::fmt; pub enum StringAction { TruncateIfTokensExceed, ErrorIfTokensExceed, + ModelPreprocessing } /// The action to be performed if the image dimensions is larger than the maximum size a @@ -16,6 +17,7 @@ pub enum StringAction { pub enum ImageAction { ResizeImage, ErrorIfDimensionsMismatch, + ModelPreprocessing } #[derive(Copy, Debug, Clone, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord, Hash)] diff --git a/ahnlich/utils/src/cli.rs b/ahnlich/utils/src/cli.rs index 875ae3c9..28dea123 100644 --- a/ahnlich/utils/src/cli.rs +++ b/ahnlich/utils/src/cli.rs @@ -83,7 +83,7 @@ impl Default for CommandLineConfig { enable_tracing: false, otel_endpoint: None, - log_level: String::from("info"), + log_level: String::from("info,hf_hub=warn"), maximum_clients: 1000, threadpool_size: 16, } diff --git a/sdk/ahnlich-client-py/ahnlich_client_py/internals/ai_query.py b/sdk/ahnlich-client-py/ahnlich_client_py/internals/ai_query.py index 96b352d9..37fb48e5 100644 --- a/sdk/ahnlich-client-py/ahnlich_client_py/internals/ai_query.py +++ b/sdk/ahnlich-client-py/ahnlich_client_py/internals/ai_query.py @@ -1,8 +1,10 @@ # pyre-strict -from dataclasses import dataclass import typing -from ahnlich_client_py.internals import serde_types as st +from dataclasses import dataclass + from ahnlich_client_py.internals import bincode +from ahnlich_client_py.internals import serde_types as st + class AIModel: VARIANTS = [] # type: typing.Sequence[typing.Type[AIModel]] @@ -11,10 +13,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, AIModel) @staticmethod - def bincode_deserialize(input: bytes) -> 'AIModel': + def bincode_deserialize(input: bytes) -> "AIModel": v, buffer = bincode.deserialize(input, AIModel) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v @@ -53,6 +55,12 @@ class AIModel__ClipVitB32Image(AIModel): INDEX = 5 # type: int pass +@dataclass(frozen=True) +class AIModel__ClipVitB32Text(AIModel): + INDEX = 6 # type: int + pass + + AIModel.VARIANTS = [ AIModel__AllMiniLML6V2, AIModel__AllMiniLML12V2, @@ -60,6 +68,7 @@ class AIModel__ClipVitB32Image(AIModel): AIModel__BGELargeEnV15, AIModel__Resnet50, AIModel__ClipVitB32Image, + AIModel__ClipVitB32Text, ] @@ -70,10 +79,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, AIQuery) @staticmethod - def bincode_deserialize(input: bytes) -> 'AIQuery': + def bincode_deserialize(input: bytes) -> "AIQuery": v, buffer = bincode.deserialize(input, AIQuery) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v @@ -140,7 +149,9 @@ class AIQuery__DropNonLinearAlgorithmIndex(AIQuery): class AIQuery__Set(AIQuery): INDEX = 7 # type: int store: str - inputs: typing.Sequence[typing.Tuple["StoreInput", typing.Dict[str, "MetadataValue"]]] + inputs: typing.Sequence[ + typing.Tuple["StoreInput", typing.Dict[str, "MetadataValue"]] + ] preprocess_action: "PreprocessAction" @@ -181,6 +192,7 @@ class AIQuery__Ping(AIQuery): INDEX = 13 # type: int pass + AIQuery.VARIANTS = [ AIQuery__CreateStore, AIQuery__GetPred, @@ -208,10 +220,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, AIServerQuery) @staticmethod - def bincode_deserialize(input: bytes) -> 
'AIServerQuery': + def bincode_deserialize(input: bytes) -> "AIServerQuery": v, buffer = bincode.deserialize(input, AIServerQuery) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v @@ -222,10 +234,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, AIStoreInputType) @staticmethod - def bincode_deserialize(input: bytes) -> 'AIStoreInputType': + def bincode_deserialize(input: bytes) -> "AIStoreInputType": v, buffer = bincode.deserialize(input, AIStoreInputType) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v @@ -240,6 +252,7 @@ class AIStoreInputType__Image(AIStoreInputType): INDEX = 1 # type: int pass + AIStoreInputType.VARIANTS = [ AIStoreInputType__RawString, AIStoreInputType__Image, @@ -253,10 +266,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, Algorithm) @staticmethod - def bincode_deserialize(input: bytes) -> 'Algorithm': + def bincode_deserialize(input: bytes) -> "Algorithm": v, buffer = bincode.deserialize(input, Algorithm) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v @@ -283,6 +296,7 @@ class Algorithm__KDTree(Algorithm): INDEX = 3 # type: int pass + Algorithm.VARIANTS = [ Algorithm__EuclideanDistance, Algorithm__DotProductSimilarity, @@ -298,10 +312,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, ImageAction) @staticmethod - def bincode_deserialize(input: bytes) -> 'ImageAction': + def bincode_deserialize(input: bytes) -> "ImageAction": v, buffer = bincode.deserialize(input, ImageAction) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v @@ -316,6 +330,7 @@ class ImageAction__ErrorIfDimensionsMismatch(ImageAction): INDEX = 1 # type: int pass + ImageAction.VARIANTS = [ ImageAction__ResizeImage, ImageAction__ErrorIfDimensionsMismatch, @@ -329,10 +344,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, MetadataValue) @staticmethod - def bincode_deserialize(input: bytes) -> 'MetadataValue': + def bincode_deserialize(input: bytes) -> "MetadataValue": v, buffer = bincode.deserialize(input, MetadataValue) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v @@ -347,6 +362,7 @@ class MetadataValue__Image(MetadataValue): INDEX = 1 # type: int value: typing.Sequence[st.uint8] + MetadataValue.VARIANTS = [ MetadataValue__RawString, MetadataValue__Image, @@ -360,10 +376,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, NonLinearAlgorithm) @staticmethod - def bincode_deserialize(input: bytes) -> 'NonLinearAlgorithm': + def bincode_deserialize(input: bytes) -> "NonLinearAlgorithm": v, buffer = bincode.deserialize(input, NonLinearAlgorithm) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v @@ -372,6 +388,7 @@ class NonLinearAlgorithm__KDTree(NonLinearAlgorithm): INDEX = 0 # type: int pass + NonLinearAlgorithm.VARIANTS = [ NonLinearAlgorithm__KDTree, ] @@ -384,10 +401,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, Predicate) @staticmethod 
- def bincode_deserialize(input: bytes) -> 'Predicate': + def bincode_deserialize(input: bytes) -> "Predicate": v, buffer = bincode.deserialize(input, Predicate) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v @@ -418,6 +435,7 @@ class Predicate__NotIn(Predicate): key: str value: typing.Sequence["MetadataValue"] + Predicate.VARIANTS = [ Predicate__Equals, Predicate__NotEquals, @@ -433,10 +451,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, PredicateCondition) @staticmethod - def bincode_deserialize(input: bytes) -> 'PredicateCondition': + def bincode_deserialize(input: bytes) -> "PredicateCondition": v, buffer = bincode.deserialize(input, PredicateCondition) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v @@ -457,6 +475,7 @@ class PredicateCondition__Or(PredicateCondition): INDEX = 2 # type: int value: typing.Tuple["PredicateCondition", "PredicateCondition"] + PredicateCondition.VARIANTS = [ PredicateCondition__Value, PredicateCondition__And, @@ -471,10 +490,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, PreprocessAction) @staticmethod - def bincode_deserialize(input: bytes) -> 'PreprocessAction': + def bincode_deserialize(input: bytes) -> "PreprocessAction": v, buffer = bincode.deserialize(input, PreprocessAction) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v @@ -489,6 +508,7 @@ class PreprocessAction__Image(PreprocessAction): INDEX = 1 # type: int value: "ImageAction" + PreprocessAction.VARIANTS = [ PreprocessAction__RawString, PreprocessAction__Image, @@ -502,10 +522,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, StoreInput) @staticmethod - def bincode_deserialize(input: bytes) -> 'StoreInput': + def bincode_deserialize(input: bytes) -> "StoreInput": v, buffer = bincode.deserialize(input, StoreInput) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v @@ -520,6 +540,7 @@ class StoreInput__Image(StoreInput): INDEX = 1 # type: int value: typing.Sequence[st.uint8] + StoreInput.VARIANTS = [ StoreInput__RawString, StoreInput__Image, @@ -533,10 +554,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, StringAction) @staticmethod - def bincode_deserialize(input: bytes) -> 'StringAction': + def bincode_deserialize(input: bytes) -> "StringAction": v, buffer = bincode.deserialize(input, StringAction) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v @@ -551,8 +572,8 @@ class StringAction__ErrorIfTokensExceed(StringAction): INDEX = 1 # type: int pass + StringAction.VARIANTS = [ StringAction__TruncateIfTokensExceed, StringAction__ErrorIfTokensExceed, ] - diff --git a/sdk/ahnlich-client-py/ahnlich_client_py/internals/ai_response.py b/sdk/ahnlich-client-py/ahnlich_client_py/internals/ai_response.py index 838cb40d..9e8cda32 100644 --- a/sdk/ahnlich-client-py/ahnlich_client_py/internals/ai_response.py +++ b/sdk/ahnlich-client-py/ahnlich_client_py/internals/ai_response.py @@ -1,8 +1,10 @@ # pyre-strict -from dataclasses import dataclass import typing -from ahnlich_client_py.internals 
import serde_types as st +from dataclasses import dataclass + from ahnlich_client_py.internals import bincode +from ahnlich_client_py.internals import serde_types as st + class AIModel: VARIANTS = [] # type: typing.Sequence[typing.Type[AIModel]] @@ -11,10 +13,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, AIModel) @staticmethod - def bincode_deserialize(input: bytes) -> 'AIModel': + def bincode_deserialize(input: bytes) -> "AIModel": v, buffer = bincode.deserialize(input, AIModel) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v @@ -53,6 +55,7 @@ class AIModel__ClipVitB32Image(AIModel): INDEX = 5 # type: int pass + AIModel.VARIANTS = [ AIModel__AllMiniLML6V2, AIModel__AllMiniLML12V2, @@ -70,10 +73,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, AIServerResponse) @staticmethod - def bincode_deserialize(input: bytes) -> 'AIServerResponse': + def bincode_deserialize(input: bytes) -> "AIServerResponse": v, buffer = bincode.deserialize(input, AIServerResponse) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v @@ -116,13 +119,21 @@ class AIServerResponse__Set(AIServerResponse): @dataclass(frozen=True) class AIServerResponse__Get(AIServerResponse): INDEX = 6 # type: int - value: typing.Sequence[typing.Tuple[typing.Optional["StoreInput"], typing.Dict[str, "MetadataValue"]]] + value: typing.Sequence[ + typing.Tuple[typing.Optional["StoreInput"], typing.Dict[str, "MetadataValue"]] + ] @dataclass(frozen=True) class AIServerResponse__GetSimN(AIServerResponse): INDEX = 7 # type: int - value: typing.Sequence[typing.Tuple[typing.Optional["StoreInput"], typing.Dict[str, "MetadataValue"], "Similarity"]] + value: typing.Sequence[ + typing.Tuple[ + typing.Optional["StoreInput"], + typing.Dict[str, "MetadataValue"], + "Similarity", + ] + ] @dataclass(frozen=True) @@ -136,6 +147,7 @@ class AIServerResponse__CreateIndex(AIServerResponse): INDEX = 9 # type: int value: st.uint64 + AIServerResponse.VARIANTS = [ AIServerResponse__Unit, AIServerResponse__Pong, @@ -158,10 +170,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, AIServerResult) @staticmethod - def bincode_deserialize(input: bytes) -> 'AIServerResult': + def bincode_deserialize(input: bytes) -> "AIServerResult": v, buffer = bincode.deserialize(input, AIServerResult) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v @@ -176,10 +188,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, AIStoreInfo) @staticmethod - def bincode_deserialize(input: bytes) -> 'AIStoreInfo': + def bincode_deserialize(input: bytes) -> "AIStoreInfo": v, buffer = bincode.deserialize(input, AIStoreInfo) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v @@ -190,10 +202,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, AIStoreInputType) @staticmethod - def bincode_deserialize(input: bytes) -> 'AIStoreInputType': + def bincode_deserialize(input: bytes) -> "AIStoreInputType": v, buffer = bincode.deserialize(input, AIStoreInputType) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes 
were not read") return v @@ -208,6 +220,7 @@ class AIStoreInputType__Image(AIStoreInputType): INDEX = 1 # type: int pass + AIStoreInputType.VARIANTS = [ AIStoreInputType__RawString, AIStoreInputType__Image, @@ -223,10 +236,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, ConnectedClient) @staticmethod - def bincode_deserialize(input: bytes) -> 'ConnectedClient': + def bincode_deserialize(input: bytes) -> "ConnectedClient": v, buffer = bincode.deserialize(input, ConnectedClient) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v @@ -237,10 +250,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, MetadataValue) @staticmethod - def bincode_deserialize(input: bytes) -> 'MetadataValue': + def bincode_deserialize(input: bytes) -> "MetadataValue": v, buffer = bincode.deserialize(input, MetadataValue) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v @@ -255,6 +268,7 @@ class MetadataValue__Image(MetadataValue): INDEX = 1 # type: int value: typing.Sequence[st.uint8] + MetadataValue.VARIANTS = [ MetadataValue__RawString, MetadataValue__Image, @@ -268,10 +282,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, Result) @staticmethod - def bincode_deserialize(input: bytes) -> 'Result': + def bincode_deserialize(input: bytes) -> "Result": v, buffer = bincode.deserialize(input, Result) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v @@ -286,6 +300,7 @@ class Result__Err(Result): INDEX = 1 # type: int value: str + Result.VARIANTS = [ Result__Ok, Result__Err, @@ -304,10 +319,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, ServerInfo) @staticmethod - def bincode_deserialize(input: bytes) -> 'ServerInfo': + def bincode_deserialize(input: bytes) -> "ServerInfo": v, buffer = bincode.deserialize(input, ServerInfo) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v @@ -318,10 +333,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, ServerType) @staticmethod - def bincode_deserialize(input: bytes) -> 'ServerType': + def bincode_deserialize(input: bytes) -> "ServerType": v, buffer = bincode.deserialize(input, ServerType) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v @@ -336,6 +351,7 @@ class ServerType__AI(ServerType): INDEX = 1 # type: int pass + ServerType.VARIANTS = [ ServerType__Database, ServerType__AI, @@ -350,10 +366,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, Similarity) @staticmethod - def bincode_deserialize(input: bytes) -> 'Similarity': + def bincode_deserialize(input: bytes) -> "Similarity": v, buffer = bincode.deserialize(input, Similarity) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v @@ -364,10 +380,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, StoreInput) @staticmethod - def bincode_deserialize(input: bytes) -> 'StoreInput': + def bincode_deserialize(input: bytes) -> "StoreInput": v, buffer = 
bincode.deserialize(input, StoreInput) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v @@ -382,6 +398,7 @@ class StoreInput__Image(StoreInput): INDEX = 1 # type: int value: typing.Sequence[st.uint8] + StoreInput.VARIANTS = [ StoreInput__RawString, StoreInput__Image, @@ -397,10 +414,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, StoreUpsert) @staticmethod - def bincode_deserialize(input: bytes) -> 'StoreUpsert': + def bincode_deserialize(input: bytes) -> "StoreUpsert": v, buffer = bincode.deserialize(input, StoreUpsert) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v @@ -413,10 +430,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, SystemTime) @staticmethod - def bincode_deserialize(input: bytes) -> 'SystemTime': + def bincode_deserialize(input: bytes) -> "SystemTime": v, buffer = bincode.deserialize(input, SystemTime) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v @@ -430,9 +447,8 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, Version) @staticmethod - def bincode_deserialize(input: bytes) -> 'Version': + def bincode_deserialize(input: bytes) -> "Version": v, buffer = bincode.deserialize(input, Version) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v - diff --git a/sdk/ahnlich-client-py/ahnlich_client_py/internals/bincode/__init__.py b/sdk/ahnlich-client-py/ahnlich_client_py/internals/bincode/__init__.py index 4e5e0837..38cbd7ff 100644 --- a/sdk/ahnlich-client-py/ahnlich_client_py/internals/bincode/__init__.py +++ b/sdk/ahnlich-client-py/ahnlich_client_py/internals/bincode/__init__.py @@ -1,16 +1,16 @@ # Copyright (c) Facebook, Inc. and its affiliates # SPDX-License-Identifier: MIT OR Apache-2.0 -import dataclasses import collections +import dataclasses import io import struct import typing from copy import copy from typing import get_type_hints -from ahnlich_client_py.internals import serde_types as st from ahnlich_client_py.internals import serde_binary as sb +from ahnlich_client_py.internals import serde_types as st # Maximum length in practice for sequences (e.g. in Java). 
MAX_LENGTH = (1 << 31) - 1 diff --git a/sdk/ahnlich-client-py/ahnlich_client_py/internals/db_query.py b/sdk/ahnlich-client-py/ahnlich_client_py/internals/db_query.py index b6120b2b..b281f346 100644 --- a/sdk/ahnlich-client-py/ahnlich_client_py/internals/db_query.py +++ b/sdk/ahnlich-client-py/ahnlich_client_py/internals/db_query.py @@ -1,8 +1,10 @@ # pyre-strict -from dataclasses import dataclass import typing -from ahnlich_client_py.internals import serde_types as st +from dataclasses import dataclass + from ahnlich_client_py.internals import bincode +from ahnlich_client_py.internals import serde_types as st + class Algorithm: VARIANTS = [] # type: typing.Sequence[typing.Type[Algorithm]] @@ -11,10 +13,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, Algorithm) @staticmethod - def bincode_deserialize(input: bytes) -> 'Algorithm': + def bincode_deserialize(input: bytes) -> "Algorithm": v, buffer = bincode.deserialize(input, Algorithm) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v @@ -41,6 +43,7 @@ class Algorithm__KDTree(Algorithm): INDEX = 3 # type: int pass + Algorithm.VARIANTS = [ Algorithm__EuclideanDistance, Algorithm__DotProductSimilarity, @@ -59,10 +62,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, Array) @staticmethod - def bincode_deserialize(input: bytes) -> 'Array': + def bincode_deserialize(input: bytes) -> "Array": v, buffer = bincode.deserialize(input, Array) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v @@ -73,10 +76,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, MetadataValue) @staticmethod - def bincode_deserialize(input: bytes) -> 'MetadataValue': + def bincode_deserialize(input: bytes) -> "MetadataValue": v, buffer = bincode.deserialize(input, MetadataValue) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v @@ -91,6 +94,7 @@ class MetadataValue__Image(MetadataValue): INDEX = 1 # type: int value: typing.Sequence[st.uint8] + MetadataValue.VARIANTS = [ MetadataValue__RawString, MetadataValue__Image, @@ -104,10 +108,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, NonLinearAlgorithm) @staticmethod - def bincode_deserialize(input: bytes) -> 'NonLinearAlgorithm': + def bincode_deserialize(input: bytes) -> "NonLinearAlgorithm": v, buffer = bincode.deserialize(input, NonLinearAlgorithm) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v @@ -116,6 +120,7 @@ class NonLinearAlgorithm__KDTree(NonLinearAlgorithm): INDEX = 0 # type: int pass + NonLinearAlgorithm.VARIANTS = [ NonLinearAlgorithm__KDTree, ] @@ -128,10 +133,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, Predicate) @staticmethod - def bincode_deserialize(input: bytes) -> 'Predicate': + def bincode_deserialize(input: bytes) -> "Predicate": v, buffer = bincode.deserialize(input, Predicate) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v @@ -162,6 +167,7 @@ class Predicate__NotIn(Predicate): key: str value: typing.Sequence["MetadataValue"] + Predicate.VARIANTS = [ 
Predicate__Equals, Predicate__NotEquals, @@ -177,10 +183,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, PredicateCondition) @staticmethod - def bincode_deserialize(input: bytes) -> 'PredicateCondition': + def bincode_deserialize(input: bytes) -> "PredicateCondition": v, buffer = bincode.deserialize(input, PredicateCondition) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v @@ -201,6 +207,7 @@ class PredicateCondition__Or(PredicateCondition): INDEX = 2 # type: int value: typing.Tuple["PredicateCondition", "PredicateCondition"] + PredicateCondition.VARIANTS = [ PredicateCondition__Value, PredicateCondition__And, @@ -215,10 +222,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, Query) @staticmethod - def bincode_deserialize(input: bytes) -> 'Query': + def bincode_deserialize(input: bytes) -> "Query": v, buffer = bincode.deserialize(input, Query) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v @@ -337,6 +344,7 @@ class Query__Ping(Query): INDEX = 15 # type: int pass + Query.VARIANTS = [ Query__CreateStore, Query__GetKey, @@ -366,9 +374,8 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, ServerQuery) @staticmethod - def bincode_deserialize(input: bytes) -> 'ServerQuery': + def bincode_deserialize(input: bytes) -> "ServerQuery": v, buffer = bincode.deserialize(input, ServerQuery) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v - diff --git a/sdk/ahnlich-client-py/ahnlich_client_py/internals/db_response.py b/sdk/ahnlich-client-py/ahnlich_client_py/internals/db_response.py index acd3baa1..d1d0a6c4 100644 --- a/sdk/ahnlich-client-py/ahnlich_client_py/internals/db_response.py +++ b/sdk/ahnlich-client-py/ahnlich_client_py/internals/db_response.py @@ -1,8 +1,10 @@ # pyre-strict -from dataclasses import dataclass import typing -from ahnlich_client_py.internals import serde_types as st +from dataclasses import dataclass + from ahnlich_client_py.internals import bincode +from ahnlich_client_py.internals import serde_types as st + @dataclass(frozen=True) class Array: @@ -14,10 +16,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, Array) @staticmethod - def bincode_deserialize(input: bytes) -> 'Array': + def bincode_deserialize(input: bytes) -> "Array": v, buffer = bincode.deserialize(input, Array) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v @@ -30,10 +32,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, ConnectedClient) @staticmethod - def bincode_deserialize(input: bytes) -> 'ConnectedClient': + def bincode_deserialize(input: bytes) -> "ConnectedClient": v, buffer = bincode.deserialize(input, ConnectedClient) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v @@ -44,10 +46,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, MetadataValue) @staticmethod - def bincode_deserialize(input: bytes) -> 'MetadataValue': + def bincode_deserialize(input: bytes) -> "MetadataValue": v, buffer = bincode.deserialize(input, MetadataValue) if buffer: - raise 
st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v @@ -62,6 +64,7 @@ class MetadataValue__Image(MetadataValue): INDEX = 1 # type: int value: typing.Sequence[st.uint8] + MetadataValue.VARIANTS = [ MetadataValue__RawString, MetadataValue__Image, @@ -75,10 +78,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, Result) @staticmethod - def bincode_deserialize(input: bytes) -> 'Result': + def bincode_deserialize(input: bytes) -> "Result": v, buffer = bincode.deserialize(input, Result) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v @@ -93,6 +96,7 @@ class Result__Err(Result): INDEX = 1 # type: int value: str + Result.VARIANTS = [ Result__Ok, Result__Err, @@ -111,10 +115,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, ServerInfo) @staticmethod - def bincode_deserialize(input: bytes) -> 'ServerInfo': + def bincode_deserialize(input: bytes) -> "ServerInfo": v, buffer = bincode.deserialize(input, ServerInfo) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v @@ -125,10 +129,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, ServerResponse) @staticmethod - def bincode_deserialize(input: bytes) -> 'ServerResponse': + def bincode_deserialize(input: bytes) -> "ServerResponse": v, buffer = bincode.deserialize(input, ServerResponse) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v @@ -177,7 +181,9 @@ class ServerResponse__Get(ServerResponse): @dataclass(frozen=True) class ServerResponse__GetSimN(ServerResponse): INDEX = 7 # type: int - value: typing.Sequence[typing.Tuple["Array", typing.Dict[str, "MetadataValue"], "Similarity"]] + value: typing.Sequence[ + typing.Tuple["Array", typing.Dict[str, "MetadataValue"], "Similarity"] + ] @dataclass(frozen=True) @@ -191,6 +197,7 @@ class ServerResponse__CreateIndex(ServerResponse): INDEX = 9 # type: int value: st.uint64 + ServerResponse.VARIANTS = [ ServerResponse__Unit, ServerResponse__Pong, @@ -213,10 +220,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, ServerResult) @staticmethod - def bincode_deserialize(input: bytes) -> 'ServerResult': + def bincode_deserialize(input: bytes) -> "ServerResult": v, buffer = bincode.deserialize(input, ServerResult) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v @@ -227,10 +234,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, ServerType) @staticmethod - def bincode_deserialize(input: bytes) -> 'ServerType': + def bincode_deserialize(input: bytes) -> "ServerType": v, buffer = bincode.deserialize(input, ServerType) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v @@ -245,6 +252,7 @@ class ServerType__AI(ServerType): INDEX = 1 # type: int pass + ServerType.VARIANTS = [ ServerType__Database, ServerType__AI, @@ -259,10 +267,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, Similarity) @staticmethod - def bincode_deserialize(input: bytes) -> 'Similarity': + def bincode_deserialize(input: 
bytes) -> "Similarity": v, buffer = bincode.deserialize(input, Similarity) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v @@ -276,10 +284,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, StoreInfo) @staticmethod - def bincode_deserialize(input: bytes) -> 'StoreInfo': + def bincode_deserialize(input: bytes) -> "StoreInfo": v, buffer = bincode.deserialize(input, StoreInfo) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v @@ -292,10 +300,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, StoreUpsert) @staticmethod - def bincode_deserialize(input: bytes) -> 'StoreUpsert': + def bincode_deserialize(input: bytes) -> "StoreUpsert": v, buffer = bincode.deserialize(input, StoreUpsert) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v @@ -308,10 +316,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, SystemTime) @staticmethod - def bincode_deserialize(input: bytes) -> 'SystemTime': + def bincode_deserialize(input: bytes) -> "SystemTime": v, buffer = bincode.deserialize(input, SystemTime) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v @@ -325,9 +333,8 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, Version) @staticmethod - def bincode_deserialize(input: bytes) -> 'Version': + def bincode_deserialize(input: bytes) -> "Version": v, buffer = bincode.deserialize(input, Version) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v - diff --git a/sdk/ahnlich-client-py/ahnlich_client_py/internals/serde_binary/__init__.py b/sdk/ahnlich-client-py/ahnlich_client_py/internals/serde_binary/__init__.py index 0730bd23..a71b03f5 100644 --- a/sdk/ahnlich-client-py/ahnlich_client_py/internals/serde_binary/__init__.py +++ b/sdk/ahnlich-client-py/ahnlich_client_py/internals/serde_binary/__init__.py @@ -7,8 +7,8 @@ Note: This internal module is currently only meant to share code between the BCS and bincode formats. Internal APIs could change in the future. """ -import dataclasses import collections +import dataclasses import io import typing from typing import get_type_hints diff --git a/sdk/ahnlich-client-py/ahnlich_client_py/internals/serde_types/__init__.py b/sdk/ahnlich-client-py/ahnlich_client_py/internals/serde_types/__init__.py index 6d72f027..1c85909c 100644 --- a/sdk/ahnlich-client-py/ahnlich_client_py/internals/serde_types/__init__.py +++ b/sdk/ahnlich-client-py/ahnlich_client_py/internals/serde_types/__init__.py @@ -1,9 +1,10 @@ # Copyright (c) Facebook, Inc. 
and its affiliates # SPDX-License-Identifier: MIT OR Apache-2.0 -import numpy as np -from dataclasses import dataclass import typing +from dataclasses import dataclass + +import numpy as np class SerializationError(ValueError): diff --git a/sdk/ahnlich-client-py/demo_embed.py b/sdk/ahnlich-client-py/demo_embed.py index c5c5c302..c9423b74 100644 --- a/sdk/ahnlich-client-py/demo_embed.py +++ b/sdk/ahnlich-client-py/demo_embed.py @@ -8,14 +8,14 @@ ai_store_payload_no_predicates = { "store_name": "Diretnan Stores", - "query_model": ai_query.AIModel__ClipVitB32Text(), - "index_model": ai_query.AIModel__ClipVitB32Text(), + "query_model": ai_query.AIModel__AllMiniLML6V2(), + "index_model": ai_query.AIModel__AllMiniLML6V2(), } ai_store_payload_with_predicates = { "store_name": "Diretnan Predication Stores", - "query_model": ai_query.AIModel__ClipVitB32Text(), - "index_model": ai_query.AIModel__ClipVitB32Text(), + "query_model": ai_query.AIModel__AllMiniLML6V2(), + "index_model": ai_query.AIModel__AllMiniLML6V2(), "predicates": ["special", "brand"], } @@ -26,6 +26,13 @@ "predicates": ["special", "brand"], } +ai_store_payload_with_predicates_images_texts = { + "store_name": "Diretnan Image Text Predication Stores", + "query_model": ai_query.AIModel__ClipVitB32Text(), + "index_model": ai_query.AIModel__ClipVitB32Image(), + "predicates": ["special", "brand"], +} + def run_insert_text(): ai_client = AhnlichAIClient(address="127.0.0.1", port=1370, connect_timeout_sec=30) @@ -75,29 +82,16 @@ def run_get_simn_text(): return builder.exec() -def run_insert_image(): - image_urls = [ - ( - "https://cdn.britannica.com/96/195196-050-3909D5BD/Michael-Jordan-1988.jpg", - "Slam Dunk Jordan", - ), - ("https://i.ebayimg.com/images/g/0-wAAOSwsQ1h5Pqc/s-l1600.webp", "Air Jordan"), - ( - "https://as2.ftcdn.net/v2/jpg/02/70/86/51/1000_F_270865104_HMpmjP3Hqt0MvdlV7QkQJful50bBzj46.jpg", - "Aeroplane", - ), - ( - "https://csaenvironmental.co.uk/wp-content/uploads/2020/06/landscape-value-600x325.jpg", - "Landscape", - ), - ] - +def insert_image(urls, store_data): ai_client = AhnlichAIClient(address="127.0.0.1", port=1370, connect_timeout_sec=30) builder = ai_client.pipeline() - builder.create_store(**ai_store_payload_with_predicates_images) - for url, brand in image_urls: + builder.create_store(**store_data) + for url, brand in urls: print("Processing image: ", url) - location = urlopen(url) + if url.startswith("http"): + location = urlopen(url) + else: + location = url image = Image.open(location) buffer = io.BytesIO() image.save(buffer, format=image.format) @@ -110,7 +104,7 @@ def run_insert_image(): ] builder.set( - store_name=ai_store_payload_with_predicates_images["store_name"], + store_name=store_data["store_name"], inputs=store_inputs, preprocess_action=ai_query.PreprocessAction__Image( ai_query.ImageAction__ResizeImage() @@ -119,6 +113,24 @@ def run_insert_image(): return builder.exec() +def run_insert_image(): + image_urls = [ + ( + "https://cdn.britannica.com/96/195196-050-3909D5BD/Michael-Jordan-1988.jpg", + "Slam Dunk Jordan", + ), + ("https://i.ebayimg.com/images/g/0-wAAOSwsQ1h5Pqc/s-l1600.webp", "Air Jordan"), + ( + "https://as2.ftcdn.net/v2/jpg/02/70/86/51/1000_F_270865104_HMpmjP3Hqt0MvdlV7QkQJful50bBzj46.jpg", + "Aeroplane", + ), + ( + "https://csaenvironmental.co.uk/wp-content/uploads/2020/06/landscape-value-600x325.jpg", + "Landscape", + ), + ] + return insert_image(image_urls, ai_store_payload_with_predicates_images) + def run_get_simn_image(): ai_client = AhnlichAIClient(address="127.0.0.1", port=1370) 
@@ -134,3 +146,39 @@ def run_get_simn_image(): algorithm=ai_query.Algorithm__CosineSimilarity(), ) return builder.exec() + + +def run_insert_image_text(): + image_urls = [ + ( + "https://imageio.forbes.com/specials-images/imageserve/632357fbf1cebc1639065099/Roger-Federer-celebrated" + "-after-beating-Lorenzo-Sonego-at-Wimbledon-last-year-/1960x0.jpg?format=jpg&width=960", + "Roger Federer", + ), + ("https://www.silverarrows.net/wp-content/uploads/2020/05/Lewis-Hamilton-Japan.jpg", "Lewis Hamilton"), + ( + "https://img.20mn.fr/B2Dto_H3RveJTzabY4IR2yk/1444x920_andreja-laski-of-team-slovenia-and-clarisse-agbegnenou" + "-team-france-compete-during-the-women-63-kg-semifinal-of-table-b-contest-on-day-four-of-the-olympic-games-" + "paris-2024-at-champs-de-mars-arena-03vulaurent-d2317-credit-laurent-vu-sipa-2407301738", + "Clarisse Agbegnenou", + ), + ( + "https://c8.alamy.com/comp/R1YEE4/london-uk-15th-november-2018-jadon-sancho-of-england-is-tackled-by-" + "christian-pulisic-of-usa-during-the-international-friendly-match-between-england-and-usa-at-wembley-" + "stadium-on-november-15th-2018-in-london-england-photo-by-matt-bradshawphcimages-credit-phc-imagesalamy-live-news-R1YEE4.jpg", + "Christian Pulisic and Sancho", + ), + ] + return insert_image(image_urls, ai_store_payload_with_predicates_images_texts) + + +def run_get_simn_image_text(): + ai_client = AhnlichAIClient(address="127.0.0.1", port=1370) + builder = ai_client.pipeline() + builder.get_sim_n( + store_name=ai_store_payload_with_predicates_images_texts["store_name"], + search_input=ai_query.StoreInput__RawString("United States vs England"), + closest_n=3, + algorithm=ai_query.Algorithm__CosineSimilarity(), + ) + return builder.exec() From 0a10b48b6c128d5b0c7ee5382f90817f9994f8e5 Mon Sep 17 00:00:00 2001 From: HabeebShopeju Date: Sun, 24 Nov 2024 07:36:05 +0000 Subject: [PATCH 04/15] Set up ORT for Text, ORT for Image yet to run --- ahnlich/ai/src/cli/server.rs | 1 + ahnlich/ai/src/engine/ai/models.rs | 55 ++- .../ai/src/engine/ai/providers/fastembed.rs | 219 ----------- ahnlich/ai/src/engine/ai/providers/mod.rs | 11 +- ahnlich/ai/src/engine/ai/providers/ort.rs | 344 ++++++++++-------- .../ai/src/engine/ai/providers/ort_helper.rs | 33 +- .../engine/ai/providers/ort_text_helper.rs | 116 ------ .../ai/providers/processors/center_crop.rs | 10 +- .../processors/imagearray_to_ndarray.rs | 10 +- .../src/engine/ai/providers/processors/mod.rs | 25 +- .../ai/providers/processors/normalize.rs | 42 ++- .../engine/ai/providers/processors/pooling.rs | 71 ++++ .../ai/providers/processors/postprocessor.rs | 81 +++++ .../ai/providers/processors/preprocessor.rs | 98 +++-- .../engine/ai/providers/processors/rescale.rs | 10 +- .../engine/ai/providers/processors/resize.rs | 10 +- .../ai/providers/processors/tokenize.rs | 144 ++++++++ ahnlich/ai/src/error.rs | 44 ++- ahnlich/ai/src/manager/mod.rs | 178 ++++----- ahnlich/ai/src/server/task.rs | 9 +- ahnlich/ai/src/tests/aiproxy_test.rs | 18 +- ahnlich/client/src/ai.rs | 6 +- ahnlich/client/src/lib.rs | 14 +- ahnlich/dsl/src/ai.rs | 12 +- ahnlich/dsl/src/syntax/syntax.pest | 9 +- ahnlich/dsl/src/tests/ai.rs | 4 +- ahnlich/types/src/ai/mod.rs | 2 +- ahnlich/types/src/ai/preprocess.rs | 27 +- 28 files changed, 855 insertions(+), 748 deletions(-) delete mode 100644 ahnlich/ai/src/engine/ai/providers/fastembed.rs delete mode 100644 ahnlich/ai/src/engine/ai/providers/ort_text_helper.rs create mode 100644 ahnlich/ai/src/engine/ai/providers/processors/pooling.rs create mode 100644 
ahnlich/ai/src/engine/ai/providers/processors/postprocessor.rs create mode 100644 ahnlich/ai/src/engine/ai/providers/processors/tokenize.rs diff --git a/ahnlich/ai/src/cli/server.rs b/ahnlich/ai/src/cli/server.rs index 50802f7f..d788b98d 100644 --- a/ahnlich/ai/src/cli/server.rs +++ b/ahnlich/ai/src/cli/server.rs @@ -135,6 +135,7 @@ impl Default for AIProxyConfig { SupportedModels::AllMiniLML6V2, SupportedModels::AllMiniLML12V2, SupportedModels::BGEBaseEnV15, + SupportedModels::BGELargeEnV15, SupportedModels::ClipVitB32Text, SupportedModels::Resnet50, SupportedModels::ClipVitB32Image, diff --git a/ahnlich/ai/src/engine/ai/models.rs b/ahnlich/ai/src/engine/ai/models.rs index 093f45ce..2590c37a 100644 --- a/ahnlich/ai/src/engine/ai/models.rs +++ b/ahnlich/ai/src/engine/ai/models.rs @@ -1,5 +1,4 @@ use crate::cli::server::SupportedModels; -use crate::engine::ai::providers::fastembed::FastEmbedProvider; use crate::engine::ai::providers::ort::ORTProvider; use crate::engine::ai::providers::ModelProviders; use crate::engine::ai::providers::ProviderTrait; @@ -20,6 +19,7 @@ use std::path::Path; use strum::Display; use serde::ser::Error as SerError; use serde::de::Error as DeError; +use tokenizers::Encoding; #[derive(Display)] pub enum ModelType { @@ -47,7 +47,7 @@ impl From<&AIModel> for Model { model_type: ModelType::Text { max_input_tokens: nonzero!(256usize), }, - provider: ModelProviders::FastEmbed(FastEmbedProvider::new()), + provider: ModelProviders::ORT(ORTProvider::new()), supported_model: SupportedModels::AllMiniLML6V2, description: String::from("Sentence Transformer model, with 6 layers, version 2"), embedding_size: nonzero!(384usize), @@ -57,7 +57,7 @@ impl From<&AIModel> for Model { // Token size source: https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2#intended-uses max_input_tokens: nonzero!(256usize), }, - provider: ModelProviders::FastEmbed(FastEmbedProvider::new()), + provider: ModelProviders::ORT(ORTProvider::new()), supported_model: SupportedModels::AllMiniLML12V2, description: String::from("Sentence Transformer model, with 12 layers, version 2."), embedding_size: nonzero!(384usize), @@ -67,7 +67,7 @@ impl From<&AIModel> for Model { // Token size source: https://huggingface.co/BAAI/bge-large-en/discussions/11#64e44de1623074ac850aa1ae max_input_tokens: nonzero!(512usize), }, - provider: ModelProviders::FastEmbed(FastEmbedProvider::new()), + provider: ModelProviders::ORT(ORTProvider::new()), supported_model: SupportedModels::BGEBaseEnV15, description: String::from( "BAAI General Embedding model with English support, base scale, version 1.5.", ), @@ -78,7 +78,7 @@ impl From<&AIModel> for Model { model_type: ModelType::Text { max_input_tokens: nonzero!(512usize), }, - provider: ModelProviders::FastEmbed(FastEmbedProvider::new()), + provider: ModelProviders::ORT(ORTProvider::new()), supported_model: SupportedModels::BGELargeEnV15, description: String::from( "BAAI General Embedding model with English support, large scale, version 1.5.", ), @@ -134,14 +134,11 @@ impl Model { #[tracing::instrument(skip(self))] pub fn model_ndarray( &self, - storeinput: Vec<ModelInput>, + modelinput: ModelInput, action_type: &InputAction, ) -> Result<Vec<StoreKey>, AIProxyError> { let store_keys = match &self.provider { - ModelProviders::FastEmbed(provider) => { - provider.run_inference(storeinput, action_type)?
- } - ModelProviders::ORT(provider) => provider.run_inference(storeinput, action_type)?, + ModelProviders::ORT(provider) => provider.run_inference(modelinput, action_type)?, }; Ok(store_keys) } @@ -173,10 +170,6 @@ impl Model { pub fn setup_provider(&mut self, cache_location: &Path) { let supported_model = self.supported_model; match &mut self.provider { - ModelProviders::FastEmbed(provider) => { - provider.set_model(&supported_model); - provider.set_cache_location(cache_location); - } ModelProviders::ORT(provider) => { provider.set_model(&supported_model); provider.set_cache_location(cache_location); @@ -186,9+179,6 @@ pub fn load(&mut self) -> Result<(), AIProxyError> { match &mut self.provider { - ModelProviders::FastEmbed(provider) => { - provider.load_model()?; - } ModelProviders::ORT(provider) => { provider.load_model()?; } @@ -198,9+188,6 @@ pub fn get(&self) -> Result<(), AIProxyError> { match &self.provider { - ModelProviders::FastEmbed(provider) => { - provider.get_model()?; - } ModelProviders::ORT(provider) => { provider.get_model()?; } @@ -258,8 +245,8 @@ impl fmt::Display for InputAction { #[derive(Debug)] pub enum ModelInput { - Text(String), - Image(ImageArray), + Texts(Vec<Encoding>), + Images(Vec<ImageArray>), } #[derive(Debug, Clone)] @@ -380,22 +367,22 @@ impl<'de> Deserialize<'de> for ImageArray { } } -impl TryFrom<StoreInput> for ModelInput { - type Error = AIProxyError; - - fn try_from(value: StoreInput) -> Result<Self, Self::Error> { - match value { - StoreInput::RawString(s) => Ok(ModelInput::Text(s)), - StoreInput::Image(bytes) => Ok(ModelInput::Image(ImageArray::try_new(bytes)?)), - } - } -} +// impl TryFrom<StoreInput> for ModelInput { +// type Error = AIProxyError; +// +// fn try_from(value: StoreInput) -> Result<Self, Self::Error> { +// match value { +// StoreInput::RawString(s) => Ok(ModelInput::Text(s)), +// StoreInput::Image(bytes) => Ok(ModelInput::Image(ImageArray::try_new(bytes)?)), +// } +// } +// } impl From<&ModelInput> for AIStoreInputType { fn from(value: &ModelInput) -> AIStoreInputType { match value { - ModelInput::Text(_) => AIStoreInputType::RawString, - ModelInput::Image(_) => AIStoreInputType::Image, + ModelInput::Texts(_) => AIStoreInputType::RawString, + ModelInput::Images(_) => AIStoreInputType::Image, } } } diff --git a/ahnlich/ai/src/engine/ai/providers/fastembed.rs b/ahnlich/ai/src/engine/ai/providers/fastembed.rs deleted file mode 100644 index 0ea2431f..00000000 --- a/ahnlich/ai/src/engine/ai/providers/fastembed.rs +++ /dev/null @@ -1,219 +0,0 @@ -use crate::cli::server::SupportedModels; -use crate::engine::ai::models::{ImageArray, InputAction, Model, ModelInput, ModelType}; -use crate::engine::ai::providers::{ProviderTrait, TextPreprocessorTrait}; -use crate::error::AIProxyError; -use ahnlich_types::ai::AIStoreInputType; -use ahnlich_types::keyval::StoreKey; -use fastembed::{EmbeddingModel, ImageEmbedding, InitOptions, TextEmbedding}; -use hf_hub::{api::sync::ApiBuilder, Cache}; -use ndarray::Array1; -use rayon::iter::Either; -use rayon::prelude::*; -use std::convert::TryFrom; -use std::fmt; -use std::path::{Path, PathBuf}; -use tiktoken_rs::{cl100k_base, CoreBPE}; - -#[derive(Default)] -pub struct FastEmbedProvider { - cache_location: Option<PathBuf>, - cache_location_extension: PathBuf, - supported_models: Option<SupportedModels>, - pub preprocessor: Option<FastEmbedPreprocessor>, - pub model: Option<FastEmbedModel>, -} - -struct Tokenizer(CoreBPE); - -// TODO (HAKSOAT): Implement Tokenizers specific to models -impl Tokenizer { - fn new(supported_models: &SupportedModels) -> Result<Self, AIProxyError> { - let _ = supported_models; - // Using ChatGPT model tokenizers as a default till we add specific implementations. - let bpe = cl100k_base().map_err(|e| AIProxyError::TokenizerInitError(e.to_string()))?; - Ok(Tokenizer(bpe)) - } -}
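For context on the Tokenizer being deleted above: it counted tokens with tiktoken's cl100k BPE regardless of the model, so the counts only approximated what each model's own tokenizer would produce. A minimal sketch of that generic counting, using the same tiktoken-rs calls as the deleted code; the sample text is made up:

use tiktoken_rs::cl100k_base;

fn main() {
    let bpe = cl100k_base().expect("cl100k vocabulary should load");
    let tokens = bpe.encode_with_special_tokens("a photo of a slam dunk");
    // cl100k counts only approximate a model's own tokenizer, e.g. CLIP's 77-token
    // text budget, which is why per-model tokenizers replace this setup.
    println!("cl100k token count: {}", tokens.len());
}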
- -pub struct FastEmbedPreprocessor { - tokenizer: Tokenizer, -} - -// TODO (HAKSOAT): Implement other preprocessors -impl TextPreprocessorTrait for FastEmbedProvider { - fn encode_str(&self, text: &str) -> Result<Vec<u32>, AIProxyError> { - let preprocessor = self.preprocessor.as_ref() - .ok_or(AIProxyError::AIModelNotInitialized)?; - let tokens = preprocessor - .tokenizer.0.encode_with_special_tokens(text); - Ok(tokens.iter().map(|token| *token as u32).collect()) - } - - fn decode_tokens(&self, tokens: Vec<u32>) -> Result<String, AIProxyError> { - let preprocessor = self.preprocessor.as_ref() - .ok_or(AIProxyError::AIModelNotInitialized)?; - let tokens = tokens.iter().map(|token| *token as usize).collect(); - let text = preprocessor - .tokenizer.0.decode(tokens) - .map_err(|_| AIProxyError::ModelTokenizationError)?; - Ok(text) - } -} - -pub enum FastEmbedModelType { - Text(EmbeddingModel), -} - -impl fmt::Debug for FastEmbedProvider { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("FastEmbedProvider") - .field("cache_location", &self.cache_location) - .field("cache_location_extension", &self.cache_location_extension) - .field("supported_models", &self.supported_models) - .finish() - } -} - -pub enum FastEmbedModel { - Text(TextEmbedding), - Image(ImageEmbedding), -} - -impl TryFrom<&SupportedModels> for FastEmbedModelType { - type Error = AIProxyError; - - fn try_from(model: &SupportedModels) -> Result<Self, Self::Error> { - let model_modality: Model = model.into(); - let model_type = match model_modality.model_type { - ModelType::Text { .. } => { - let model_type = match model { - SupportedModels::AllMiniLML6V2 => EmbeddingModel::AllMiniLML6V2, - SupportedModels::AllMiniLML12V2 => EmbeddingModel::AllMiniLML12V2, - SupportedModels::BGEBaseEnV15 => EmbeddingModel::BGEBaseENV15, - SupportedModels::BGELargeEnV15 => EmbeddingModel::BGELargeENV15, - _ => return Err(AIProxyError::AIModelNotSupported { model_name: model.to_string() }), - }; - FastEmbedModelType::Text(model_type) - } - _ => return Err(AIProxyError::AIModelNotSupported { model_name: model.to_string() }), - }; - Ok(model_type) - } -} - -impl FastEmbedProvider { - pub(crate) fn new() -> Self { - FastEmbedProvider { - cache_location: None, - cache_location_extension: PathBuf::from("huggingface".to_string()), - supported_models: None, - model: None, - preprocessor: None, - } - } - - fn load_tokenizer(&mut self) -> Result<(), AIProxyError> { - let Some(supported_models) = self.supported_models else { - return Err(AIProxyError::AIModelNotInitialized); - }; - let tokenizer = Tokenizer::new(&supported_models)?; - self.preprocessor = Some(FastEmbedPreprocessor { tokenizer }); - Ok(()) - } -} - -impl ProviderTrait for FastEmbedProvider { - fn set_cache_location(&mut self, location: &Path) { - self.cache_location = Some(location.join(self.cache_location_extension.clone())); - } - - fn set_model(&mut self, model: &SupportedModels) { - self.supported_models = Some(*model); - } - - fn load_model(&mut self) -> Result<(), AIProxyError> { - let Some(cache_location) = self.cache_location.clone() else { - return Err(AIProxyError::CacheLocationNotInitiailized); - }; - let Some(model_type) = self.supported_models else { - return Err(AIProxyError::AIModelNotInitialized); - }; - let model_type = FastEmbedModelType::try_from(&model_type)?; - let FastEmbedModelType::Text(model_type) = model_type; - let model =
TextEmbedding::try_new(InitOptions::new(model_type).with_cache_dir(cache_location)) - .map_err(|e| AIProxyError::TextEmbeddingInitError(e.to_string()))?; - self.model = Some(FastEmbedModel::Text(model)); - self.load_tokenizer() - } - - fn get_model(&self) -> Result<(), AIProxyError> { - let Some(cache_location) = self.cache_location.clone() else { - return Err(AIProxyError::CacheLocationNotInitiailized); - }; - let Some(model_type) = self.supported_models else { - return Err(AIProxyError::AIModelNotInitialized); - }; - let model_type = FastEmbedModelType::try_from(&model_type)?; - match model_type { - FastEmbedModelType::Text(model_type) => { - let cache = Cache::new(cache_location); - let api = ApiBuilder::from_cache(cache) - .with_progress(true) - .build() - .map_err(|e| AIProxyError::APIBuilderError(e.to_string()))?; - let model_repo = api.model(model_type.to_string()); - let model_info = TextEmbedding::get_model_info(&model_type) - .map_err(|e| AIProxyError::TextEmbeddingInitError(e.to_string()))?; - model_repo - .get(model_info.model_file.as_str()) - .map_err(|e| AIProxyError::APIBuilderError(e.to_string()))?; - Ok(()) - } - } - } - - fn run_inference( - &self, - inputs: Vec, - action_type: &InputAction, - ) -> Result, AIProxyError> { - return if let Some(fastembed_model) = &self.model { - let (string_inputs, image_inputs): (Vec, Vec) = - inputs.into_par_iter().partition_map(|input| match input { - ModelInput::Text(value) => Either::Left(value), - ModelInput::Image(value) => Either::Right(value), - }); - - if !image_inputs.is_empty() { - let store_input_type: AIStoreInputType = AIStoreInputType::Image; - let Some(index_model_repr) = self.supported_models else { - return Err(AIProxyError::AIModelNotInitialized); - }; - let index_model_repr: Model = (&index_model_repr).into(); - return Err(AIProxyError::StoreTypeMismatchError { - action: *action_type, - index_model_type: index_model_repr.input_type(), - storeinput_type: store_input_type, - }); - } - let FastEmbedModel::Text(model) = fastembed_model else { - return Err(AIProxyError::AIModelNotSupported { model_name: self.supported_models.unwrap().to_string() }); - }; - let batch_size = 16; - let store_keys = model - .embed(string_inputs, Some(batch_size)) - .map_err(|e| AIProxyError::ModelProviderRunInferenceError(e.to_string()))? 
- .iter() - .try_fold(Vec::new(), |mut accumulator, embedding| { - accumulator.push(StoreKey(>::from(embedding.to_owned()))); - Ok(accumulator) - }); - store_keys - } else { - Err(AIProxyError::AIModelNotSupported { - model_name: self.supported_models.unwrap().to_string(), - }) - }; - } -} diff --git a/ahnlich/ai/src/engine/ai/providers/mod.rs b/ahnlich/ai/src/engine/ai/providers/mod.rs index c241601a..8aa78bd2 100644 --- a/ahnlich/ai/src/engine/ai/providers/mod.rs +++ b/ahnlich/ai/src/engine/ai/providers/mod.rs @@ -1,12 +1,9 @@ -pub(crate) mod fastembed; pub(crate) mod ort; -mod ort_text_helper; mod ort_helper; mod processors; use crate::cli::server::SupportedModels; use crate::engine::ai::models::{InputAction, ModelInput}; -use crate::engine::ai::providers::fastembed::FastEmbedProvider; use crate::engine::ai::providers::ort::ORTProvider; use crate::error::AIProxyError; use ahnlich_types::keyval::StoreKey; @@ -16,7 +13,6 @@ use strum::EnumIter; #[derive(Debug, EnumIter)] pub enum ModelProviders { - FastEmbed(FastEmbedProvider), ORT(ORTProvider), } @@ -27,12 +23,7 @@ pub trait ProviderTrait: std::fmt::Debug + Send + Sync { fn get_model(&self) -> Result<(), AIProxyError>; fn run_inference( &self, - input: Vec, + input: ModelInput, action_type: &InputAction, ) -> Result, AIProxyError>; } - -pub trait TextPreprocessorTrait { - fn encode_str(&self, text: &str) -> Result, AIProxyError>; - fn decode_tokens(&self, tokens: Vec) -> Result; -} diff --git a/ahnlich/ai/src/engine/ai/providers/ort.rs b/ahnlich/ai/src/engine/ai/providers/ort.rs index 0f4851e8..f356d71a 100644 --- a/ahnlich/ai/src/engine/ai/providers/ort.rs +++ b/ahnlich/ai/src/engine/ai/providers/ort.rs @@ -1,8 +1,7 @@ use crate::cli::server::SupportedModels; use crate::engine::ai::models::{ImageArray, InputAction, Model, ModelInput}; -use crate::engine::ai::providers::ort_text_helper::{get_tokenizer_artifacts_hf_hub, load_tokenizer_artifacts_hf_hub, - normalize}; -use crate::engine::ai::providers::{ProviderTrait, TextPreprocessorTrait}; +use crate::engine::ai::providers::processors::tokenize::Tokenize; +use crate::engine::ai::providers::ProviderTrait; use crate::error::AIProxyError; use fallible_collections::FallibleVec; use hf_hub::{api::sync::ApiBuilder, Cache}; @@ -12,19 +11,26 @@ use ort::{Session, Value}; use rayon::prelude::*; use ahnlich_types::keyval::StoreKey; -use ndarray::{Array, Array1, Axis, Ix4}; +use ndarray::{Array, Array1, Axis, Ix2, Ix3, Ix4, IxDyn, IxDynImpl}; use std::convert::TryFrom; use std::default::Default; use std::fmt; use std::path::{Path, PathBuf}; use std::thread::available_parallelism; -use crate::engine::ai::providers::processors::preprocessor::{ImagePreprocessorFiles, ORTImagePreprocessor, ORTPreprocessor, ORTTextPreprocessor}; +use tokenizers::Encoding; +use crate::engine::ai::providers::processors::preprocessor::{ImagePreprocessorFiles, ORTImagePreprocessor, ORTPreprocessor, ORTTextPreprocessor, TextPreprocessorFiles}; +use crate::engine::ai::providers::ort_helper::normalize; +use ndarray::s; +use tokenizers::Tokenizer; +use crate::engine::ai::providers::processors::postprocessor::{ORTPostprocessor, ORTTextPostprocessor}; + #[derive(Default)] pub struct ORTProvider { cache_location: Option, cache_location_extension: PathBuf, supported_models: Option, pub preprocessor: Option, + pub postprocessor: Option, pub model: Option, } @@ -53,8 +59,7 @@ pub struct ORTTextModel { repo_name: String, weights_file: String, session: Option, - input_params: Vec, - output_param: String, + preprocessor_files: 
TextPreprocessorFiles
}

pub enum ORTModel {
@@ -85,8 +90,26 @@
SupportedModels::ClipVitB32Text => Ok(ORTModel::Text(ORTTextModel {
repo_name: "Qdrant/clip-ViT-B-32-text".to_string(),
weights_file: "model.onnx".to_string(),
- input_params: vec!["input_ids".to_string(), "attention_mask".to_string()],
- output_param: "text_embeds".to_string(),
+ ..Default::default()
+ })),
+ SupportedModels::AllMiniLML6V2 => Ok(ORTModel::Text(ORTTextModel {
+ repo_name: "Qdrant/all-MiniLM-L6-v2-onnx".to_string(),
+ weights_file: "model.onnx".to_string(),
+ ..Default::default()
+ })),
+ SupportedModels::AllMiniLML12V2 => Ok(ORTModel::Text(ORTTextModel {
+ repo_name: "Xenova/all-MiniLM-L12-v2".to_string(),
+ weights_file: "onnx/model.onnx".to_string(),
+ ..Default::default()
+ })),
+ SupportedModels::BGEBaseEnV15 => Ok(ORTModel::Text(ORTTextModel {
+ repo_name: "Xenova/bge-base-en-v1.5".to_string(),
+ weights_file: "onnx/model.onnx".to_string(),
+ ..Default::default()
+ })),
+ SupportedModels::BGELargeEnV15 => Ok(ORTModel::Text(ORTTextModel {
+ repo_name: "Xenova/bge-large-en-v1.5".to_string(),
+ weights_file: "onnx/model.onnx".to_string(),
..Default::default()
})),
_ => Err(AIProxyError::AIModelNotSupported {
@@ -106,10 +129,15 @@ impl ORTProvider {
preprocessor: None,
supported_models: None,
model: None,
+ postprocessor: None,
}
}

- pub fn preprocess(&self, data: Vec<ImageArray>) -> Result<Array<f32, Ix4>, AIProxyError> {
+ fn get_postprocessor() -> Result<(), AIProxyError> {
+ Ok(())
+ }
+
+ pub fn preprocess_images(&self, data: Vec<ImageArray>) -> Result<Array<f32, Ix4>, AIProxyError> {
match &self.preprocessor {
Some(ORTPreprocessor::Image(preprocessor)) => {
let output_data = preprocessor.process(data)
@@ -124,12 +152,75 @@ impl ORTProvider {
}
}

+ pub fn preprocess_texts(&self, data: Vec<String>, truncate: bool) -> Result<Vec<Encoding>, AIProxyError> {
+ match &self.preprocessor {
+ Some(ORTPreprocessor::Text(preprocessor)) => {
+ let output_data = preprocessor.process(data, truncate)
+ .map_err(
+ |e| AIProxyError::ModelProviderPreprocessingError(
+ format!("Preprocessing failed for {:?} with error: {}",
+ self.supported_models.unwrap().to_string(), e)
+ ))?;
+ Ok(output_data)
+ }
+ _ => Err(AIProxyError::ModelPreprocessingError {
+ model_name: self.supported_models.unwrap().to_string(),
+ message: "Preprocessor not initialized".to_string(),
+ })
+ }
+ }
+
+ pub fn postprocess_text_embeddings(&self, embeddings: Array<f32, IxDyn>, attention_mask: Array<i64, Ix2>) -> Result<Array<f32, Ix2>, AIProxyError> {
+ let embeddings = match embeddings.shape().len() {
+ 3 => {
+ let existing_shape = embeddings.shape().to_vec();
+ Ok(embeddings.into_dimensionality::<Ix3>()
+ .map_err(
+ |e| AIProxyError::ModelPostprocessingError {
+ model_name: self.supported_models.unwrap().to_string(),
+ message: format!("Unable to convert into 3D array. Existing shape {:?}. {:?}", existing_shape, e.to_string())
+ })?.to_owned())
+ }
+ 2 => {
+ let existing_shape = embeddings.shape().to_vec();
+ let intermediate = embeddings.into_dimensionality::<Ix2>()
+ .map_err(
+ |e| AIProxyError::ModelPostprocessingError {
+ model_name: self.supported_models.unwrap().to_string(),
+ message: format!("Unable to convert into 2D. Existing shape {:?}. {:?}", existing_shape, e.to_string())
+ })?.to_owned();
+ return Ok(intermediate)
+ }
+ _ => {
+ Err(AIProxyError::ModelPostprocessingError {
+ model_name: self.supported_models.unwrap().to_string(),
+ message: format!("Unsupported shape for postprocessing.
Shape: {:?}", embeddings.shape()) + }) + } + }?; + match &self.postprocessor { + Some(ORTPostprocessor::Text(postprocessor)) => { + let output_data = postprocessor.process(embeddings, attention_mask) + .map_err( + |e| AIProxyError::ModelProviderPostprocessingError( + format!("Postprocessing failed for {:?} with error: {}", + self.supported_models.unwrap().to_string(), e) + ))?; + Ok(output_data) + } + _ => Err(AIProxyError::ModelPostprocessingError { + model_name: self.supported_models.unwrap().to_string(), + message: "Postprocessor not initialized".to_string(), + }) + } + } + pub fn batch_inference_image(&self, inputs: Vec) -> Result, AIProxyError> { let model = match &self.model { Some(ORTModel::Image(model)) => model, _ => return Err(AIProxyError::AIModelNotSupported { model_name: self.supported_models.unwrap().to_string() }), }; - let pixel_values_array = self.preprocess(inputs)?; + let pixel_values_array = self.preprocess_images(inputs)?; match &model.session { Some(session) => { let session_inputs = ort::inputs![ @@ -161,84 +252,97 @@ impl ORTProvider { } } - pub fn batch_inference_text(&self, inputs: Vec) -> Result, AIProxyError> { - let inputs = inputs.iter().map(|x| x.as_str()).collect::>(); + pub fn batch_inference_text(&self, encodings: Vec) -> Result, AIProxyError> { let model = match &self.model { Some(ORTModel::Text(model)) => model, _ => return Err(AIProxyError::AIModelNotSupported { model_name: self.supported_models.unwrap().to_string() }), }; - let batch_size = inputs.len(); - let encodings = match &self.preprocessor { - Some(ORTPreprocessor::Text(preprocessor)) => { - // TODO: We encode tokens at the preprocess step early in the workflow then also encode here. - // Find a way to store those encoded tokens for reuse here. - preprocessor.tokenizer.encode_batch(inputs, true).map_err(|_| { - AIProxyError::ModelTokenizationError - })? - } - _ => return Err(AIProxyError::AIModelNotInitialized) - }; - + let batch_size = encodings.len(); // Extract the encoding length and batch size let encoding_length = encodings[0].len(); - let max_size = encoding_length * batch_size; - // Preallocate arrays with the maximum size - let mut ids_array = Vec::with_capacity(max_size); - let mut mask_array = Vec::with_capacity(max_size); - - // Not using par_iter because the closure needs to be FnMut - encodings.iter().for_each(|encoding| { - let ids = encoding.get_ids(); - let mask = encoding.get_attention_mask(); - - // Extend the preallocated arrays with the current encoding - // Requires the closure to be FnMut - ids_array.extend(ids.iter().map(|x| *x as i64)); - mask_array.extend(mask.iter().map(|x| *x as i64)); - }); - - // Create CowArrays from vectors - let inputs_ids_array = - Array::from_shape_vec((batch_size, encoding_length), ids_array) - .map_err(|e| { - AIProxyError::ModelProviderPreprocessingError(e.to_string()) - })?; - - let attention_mask_array = - Array::from_shape_vec((batch_size, encoding_length), mask_array).map_err(|e| { - AIProxyError::ModelProviderPreprocessingError(e.to_string()) - })?; - match &model.session { Some(session) => { - let session_inputs = ort::inputs![ - model.input_params.first() - .expect("Hardcoded in parameters").as_str() => Value::from_array(inputs_ids_array)?, - model.input_params.get(1) - .expect("Hardcoded in parameters").as_str() => Value::from_array(attention_mask_array.view())? 
- ].map_err(|e| AIProxyError::ModelProviderPreprocessingError(e.to_string()))?; - let outputs = session.run(session_inputs) - .map_err(|e| AIProxyError::ModelProviderRunInferenceError(e.to_string()))?; - let last_hidden_state_key = match outputs.len() { - 1 => outputs - .keys() - .next() - .expect("Should not happen as length was checked"), - _ => model.output_param.as_str(), + let need_token_type_ids = session + .inputs + .iter() + .any(|input| input.name == "token_type_ids"); + // Preallocate arrays with the maximum size + let mut ids_array = Vec::with_capacity(max_size); + let mut mask_array = Vec::with_capacity(max_size); + let mut token_type_ids_array: Option> = None; + if need_token_type_ids { + token_type_ids_array = Some(Vec::with_capacity(max_size)); + } + + // Not using par_iter because the closure needs to be FnMut + encodings.iter().for_each(|encoding| { + let ids = encoding.get_ids(); + let mask = encoding.get_attention_mask(); + + // Extend the preallocated arrays with the current encoding + // Requires the closure to be FnMut + ids_array.extend(ids.iter().map(|x| *x as i64)); + mask_array.extend(mask.iter().map(|x| *x as i64)); + match token_type_ids_array { + Some(ref mut token_type_ids_array) => { + token_type_ids_array.extend(encoding.get_type_ids().iter().map(|x| *x as i64)); + } + None => {} + } + }); + + // Create CowArrays from vectors + let inputs_ids_array = + Array::from_shape_vec((batch_size, encoding_length), ids_array) + .map_err(|e| { + AIProxyError::ModelProviderPreprocessingError(e.to_string()) + })?; + + let attention_mask_array = + Array::from_shape_vec((batch_size, encoding_length), mask_array).map_err(|e| { + AIProxyError::ModelProviderPreprocessingError(e.to_string()) + })?; + + let token_type_ids_array = match token_type_ids_array { + Some(token_type_ids_array) => { + Some(Array::from_shape_vec((batch_size, encoding_length), token_type_ids_array) + .map_err(|e| { + AIProxyError::ModelProviderPreprocessingError(e.to_string()) + })?) + }, + None => None, }; - let output_data = outputs[last_hidden_state_key] + let mut session_inputs = ort::inputs![ + "input_ids" => Value::from_array(inputs_ids_array)?, + "attention_mask" => Value::from_array(attention_mask_array.view())? 
+ ].map_err(|e| AIProxyError::ModelProviderPreprocessingError(e.to_string()))?; + match token_type_ids_array { + Some(token_type_ids_array) => { + session_inputs.push(( + "token_type_ids".into(), + Value::from_array(token_type_ids_array)?.into(), + )); + } + None => {} + } + + let output_key = session.outputs.first().expect("Must exist").name.clone(); + let session_outputs = session.run(session_inputs) + .map_err(|e| AIProxyError::ModelProviderRunInferenceError(e.to_string()))?; + let session_output = session_outputs[output_key.as_str()] .try_extract_tensor::() .map_err(|e| AIProxyError::ModelProviderPostprocessingError(e.to_string()))?; - let store_keys = output_data + let session_output = session_output + .to_owned(); + let embeddings = self.postprocess_text_embeddings(session_output, attention_mask_array)?; + println!("Embeddings: {:?}", embeddings); + let store_keys = embeddings .axis_iter(Axis(0)) .into_par_iter() - .map(|row| { - let embeddings = normalize(row.as_slice().unwrap()); - StoreKey(>::from(embeddings)) - }) + .map(|embedding| StoreKey(>::from(embedding.to_owned()))) .collect(); Ok(store_keys) } @@ -247,37 +351,6 @@ impl ORTProvider { } } -impl TextPreprocessorTrait for ORTProvider { - fn encode_str(&self, text: &str) -> Result, AIProxyError> { - match &self.model { - Some(ORTModel::Text(model)) => model, - _ => return Err(AIProxyError::AIModelNotSupported { model_name: self.supported_models.unwrap().to_string() }), - }; - - let Some(ORTPreprocessor::Text(preprocessor)) = &self.preprocessor else { - return Err(AIProxyError::AIModelNotInitialized); - }; - - let tokens = preprocessor.tokenizer.encode(text, true) - .map_err(|_| {AIProxyError::ModelTokenizationError})?; - let token_ids = tokens.get_ids(); - Ok(token_ids.to_vec()) - } - - fn decode_tokens(&self, token_ids: Vec) -> Result { - match &self.model { - Some(ORTModel::Text(model)) => model, - _ => return Err(AIProxyError::AIModelNotSupported { model_name: self.supported_models.unwrap().to_string() }), - }; - - let Some(ORTPreprocessor::Text(preprocessor)) = &self.preprocessor else { - return Err(AIProxyError::AIModelNotInitialized); - }; - - preprocessor.tokenizer.decode(&token_ids, true) - .map_err(|_| {AIProxyError::ModelTokenizationError}) - } -} impl ProviderTrait for ORTProvider { fn set_cache_location(&mut self, location: &Path) { @@ -329,6 +402,7 @@ impl ProviderTrait for ORTProvider { input_params: input_param, output_param, session: Some(session), + preprocessor_files: preprocessor_files.clone(), ..Default::default() })); let mut preprocessor = ORTImagePreprocessor::default(); @@ -339,8 +413,7 @@ impl ProviderTrait for ORTProvider { ORTModel::Text(ORTTextModel { weights_file, repo_name, - input_params, - output_param, + preprocessor_files, .. }) => { let model_repo = api.model(repo_name.clone()); @@ -350,23 +423,16 @@ impl ProviderTrait for ORTProvider { let session = Session::builder()? .with_intra_threads(threads)? 
.commit_from_file(model_file_reference)?; - let max_token_length = Model::from(&(self.supported_models - .ok_or(AIProxyError::AIModelNotInitialized)?)) - .max_input_token() - .ok_or(AIProxyError::AIModelNotInitialized)?; - let tokenizer = load_tokenizer_artifacts_hf_hub(&model_repo, - usize::from(max_token_length))?; self.model = Some(ORTModel::Text(ORTTextModel { repo_name, weights_file, - input_params, - output_param, session: Some(session), + preprocessor_files: preprocessor_files.clone(), })); - self.preprocessor = Some(ORTPreprocessor::Text( - ORTTextPreprocessor { - tokenizer, - })); + let preprocessor = ORTTextPreprocessor::load(model_repo, preprocessor_files)?; + self.preprocessor = Some(ORTPreprocessor::Text(preprocessor)); + let postprocessor = ORTTextPostprocessor::load(supported_model)?; + self.postprocessor = Some(ORTPostprocessor::Text(postprocessor)); } } Ok(()) @@ -404,10 +470,11 @@ impl ProviderTrait for ORTProvider { }, ORTModel::Text(ORTTextModel { repo_name, + preprocessor_files, .. }) => { - let model_repo = api.model(repo_name); - get_tokenizer_artifacts_hf_hub(&model_repo)?; + let model_repo = api.model(repo_name.clone()); + Tokenize::download_artifacts(preprocessor_files.tokenize, model_repo)?; Ok(()) } } @@ -415,31 +482,22 @@ impl ProviderTrait for ORTProvider { fn run_inference( &self, - inputs: Vec, + input: ModelInput, _action_type: &InputAction, ) -> Result, AIProxyError> { - let (string_inputs, image_inputs): (Vec, Vec) = - inputs.into_par_iter().partition_map(|input| match input { - ModelInput::Text(value) => Either::Left(value), - ModelInput::Image(value) => Either::Right(value), - }); - - if !image_inputs.is_empty() && !string_inputs.is_empty() { - return Err(AIProxyError::VaryingInferenceInputTypes) - } - let batch_size = 16; - let mut store_keys: Vec<_> = FallibleVec::try_with_capacity( - image_inputs.len().max(string_inputs.len()) - )?; - if !image_inputs.is_empty() { - for batch_inputs in image_inputs.into_iter().chunks(batch_size).into_iter() { - store_keys.extend(self.batch_inference_image(batch_inputs.collect())?); - } - } else { - for batch_inputs in string_inputs.into_iter().chunks(batch_size).into_iter() { - store_keys.extend(self.batch_inference_text(batch_inputs.collect())?); - } + + match input { + ModelInput::Images(images) => self.batch_inference_image(images), + ModelInput::Texts(encodings) => { + let mut store_keys: Vec<_> = FallibleVec::try_with_capacity( + encodings.len() + )?; + + for batch_encoding in encodings.into_iter().chunks(16).into_iter() { + store_keys.extend(self.batch_inference_text(batch_encoding.collect())?); + } + Ok(store_keys) + }, } - Ok(store_keys) } } diff --git a/ahnlich/ai/src/engine/ai/providers/ort_helper.rs b/ahnlich/ai/src/engine/ai/providers/ort_helper.rs index e665526d..b9cdaba9 100644 --- a/ahnlich/ai/src/engine/ai/providers/ort_helper.rs +++ b/ahnlich/ai/src/engine/ai/providers/ort_helper.rs @@ -5,14 +5,21 @@ use std::collections::HashMap; use std::fs::File; use std::io::Read; use std::path::PathBuf; +use rayon::prelude::*; /// Public function to read a file to bytes. /// To be used when loading local model files. 
pub fn read_file_to_bytes(file: &PathBuf) -> Result<Vec<u8>, AIProxyError> {
- let mut file = File::open(file).map_err(|_| AIProxyError::ModelTokenizerLoadError)?;
- let file_size = file.metadata().map_err(|_| AIProxyError::ModelTokenizerLoadError)?.len() as usize;
+ let mut file = File::open(file).map_err(|_| AIProxyError::ModelConfigLoadError {
+ message: format!("failed to open file {:?}", file),
+ })?;
+ let file_size = file.metadata().map_err(|_| AIProxyError::ModelConfigLoadError {
+ message: format!("failed to get metadata for file {:?}", file),
+ })?.len() as usize;
let mut buffer = Vec::with_capacity(file_size);
- file.read_to_end(&mut buffer).map_err(|_| AIProxyError::ModelTokenizerLoadError)?;
+ file.read_to_end(&mut buffer).map_err(|_| AIProxyError::ModelConfigLoadError {
+ message: format!("failed to read file {:?}", file),
+ })?;
Ok(buffer)
}
@@ -33,15 +40,25 @@ impl HFConfigReader {
if let Some(value) = self.cache.get(config_name) {
return value.clone();
}
- let file = self.model_repo.get(config_name).map_err(|_| AIProxyError::ModelConfigLoadError{
- message: format!("failed to parse {}", config_name),
+ let file = self.model_repo.get(config_name).map_err(|e| AIProxyError::ModelConfigLoadError{
+ message: format!("failed to fetch {}, {}", config_name, e.to_string()),
+ })?;
+ let contents = read_file_to_bytes(&file).map_err(|e| AIProxyError::ModelConfigLoadError{
+ message: format!("failed to read {}, {}", config_name, e.to_string()),
})?;
- let contents = read_file_to_bytes(&file).unwrap();
let value: serde_json::Value = serde_json::from_slice(&contents).map_err(
- |_| AIProxyError::ModelConfigLoadError{
- message: format!("failed to parse {}", config_name),
+ |e| AIProxyError::ModelConfigLoadError{
+ message: format!("failed to parse {}, {}", config_name, e.to_string()),
})?;
self.cache.insert(config_name.to_string(), Ok(value.clone()));
Ok(value)
}
+}
+
+pub fn normalize(v: &[f32]) -> Vec<f32> {
+ let norm = (v.par_iter().map(|val| val * val).sum::<f32>()).sqrt();
+ let epsilon = 1e-12;
+
+ // We add the super-small epsilon to avoid dividing by zero
+ v.par_iter().map(|&val| val / (norm + epsilon)).collect()
+}
\ No newline at end of file
diff --git a/ahnlich/ai/src/engine/ai/providers/ort_text_helper.rs b/ahnlich/ai/src/engine/ai/providers/ort_text_helper.rs
deleted file mode 100644
index 95b77f37..00000000
--- a/ahnlich/ai/src/engine/ai/providers/ort_text_helper.rs
+++ /dev/null
@@ -1,116 +0,0 @@
-// This script was adapted from FastEmbed
-// https://github.com/Anush008/fastembed-rs/blob/474d4e62c87666781b580ffc076b8475b693fc34/src/common.rs
-use hf_hub::api::sync::ApiRepo;
-use rayon::prelude::*;
-use tokenizers::decoders::bpe::BPEDecoder;
-use tokenizers::{AddedToken, PaddingParams, PaddingStrategy, Tokenizer, TruncationParams};
-use crate::engine::ai::providers::ort_helper;
-use crate::error::AIProxyError;
-
-
-#[derive(Debug, Clone, PartialEq, Eq)]
-pub struct TokenizerFiles {
- pub tokenizer_file: Vec<u8>,
- pub config_file: Vec<u8>,
- pub special_tokens_map_file: Vec<u8>,
- pub tokenizer_config_file: Vec<u8>,
-}
-
-/// The procedure for loading tokenizer files from the hugging face hub is separated
-/// from the main load_tokenizer function (which is expecting bytes, from any source).
-pub fn load_tokenizer_artifacts_hf_hub(model_repo: &ApiRepo, max_length: usize) -> Result { - let tokenizer_files: TokenizerFiles = get_tokenizer_artifacts_hf_hub(model_repo)?; - load_tokenizer(tokenizer_files, max_length) -} - -pub fn get_tokenizer_artifacts_hf_hub(model_repo: &ApiRepo) -> Result { - Ok(TokenizerFiles { - tokenizer_file: ort_helper::read_file_to_bytes(&model_repo.get("tokenizer.json") - .map_err(|_| AIProxyError::ModelTokenizerLoadError)?) - .map_err(|_| AIProxyError::ModelTokenizerLoadError)?, - config_file: ort_helper::read_file_to_bytes(&model_repo.get("config.json") - .map_err(|_| AIProxyError::ModelTokenizerLoadError)?) - .map_err(|_| AIProxyError::ModelTokenizerLoadError)?, - special_tokens_map_file: ort_helper::read_file_to_bytes(&model_repo.get("special_tokens_map.json") - .map_err(|_| AIProxyError::ModelTokenizerLoadError)?) - .map_err(|_| AIProxyError::ModelTokenizerLoadError)?, - tokenizer_config_file: ort_helper::read_file_to_bytes(&model_repo.get("tokenizer_config.json") - .map_err(|_| AIProxyError::ModelTokenizerLoadError)?) - .map_err(|_| AIProxyError::ModelTokenizerLoadError)?, - }) -} - -/// Function can be called directly from the try_new_from_user_defined function (providing file bytes) -/// -/// Or indirectly from the try_new function via load_tokenizer_hf_hub (converting HF files to bytes) -pub fn load_tokenizer(tokenizer_files: TokenizerFiles, max_length: usize) -> Result { - // Serialise each tokenizer file - let config: serde_json::Value = - serde_json::from_slice(&tokenizer_files.config_file).map_err(|_| AIProxyError::ModelTokenizerLoadError)?; - let special_tokens_map: serde_json::Value = - serde_json::from_slice(&tokenizer_files.special_tokens_map_file).map_err(|_| AIProxyError::ModelTokenizerLoadError)?; - let tokenizer_config: serde_json::Value = - serde_json::from_slice(&tokenizer_files.tokenizer_config_file).map_err(|_| AIProxyError::ModelTokenizerLoadError)?; - let mut tokenizer: tokenizers::Tokenizer = - tokenizers::Tokenizer::from_bytes(tokenizer_files.tokenizer_file).map_err(|_| AIProxyError::ModelTokenizerLoadError)?; - - //For BGEBaseSmall, the model_max_length value is set to 1000000000000000019884624838656. Which fits in a f64 - let model_max_length = tokenizer_config["model_max_length"] - .as_f64() - .ok_or(AIProxyError::ModelTokenizerLoadError)? as f32; - let max_length = max_length.min(model_max_length as usize); - let pad_id = config["pad_token_id"].as_u64().unwrap_or(0) as u32; - let pad_token = tokenizer_config["pad_token"] - .as_str() - .ok_or(AIProxyError::ModelTokenizerLoadError)? - .into(); - - let mut tokenizer = tokenizer - .with_padding(Some(PaddingParams { - // TODO: the user should able to choose the padding strategy - strategy: PaddingStrategy::BatchLongest, - pad_token, - pad_id, - ..Default::default() - })) - .with_truncation(Some(TruncationParams { - max_length, - ..Default::default() - })) - .map_err(|_| AIProxyError::ModelTokenizerLoadError)? 
- .clone(); - if let serde_json::Value::Object(root_object) = special_tokens_map { - for (_, value) in root_object.iter() { - if value.is_string() { - tokenizer.add_special_tokens(&[AddedToken { - content: value.as_str().unwrap().into(), - special: true, - ..Default::default() - }]); - } else if value.is_object() { - tokenizer.add_special_tokens(&[AddedToken { - content: value["content"].as_str().unwrap().into(), - special: true, - single_word: value["single_word"].as_bool().unwrap(), - lstrip: value["lstrip"].as_bool().unwrap(), - rstrip: value["rstrip"].as_bool().unwrap(), - normalized: value["normalized"].as_bool().unwrap(), - }]); - } - } - } - - let decoder = BPEDecoder::new("".to_string()); - tokenizer.with_decoder(Some(decoder)); - - Ok(tokenizer.into()) -} - -pub fn normalize(v: &[f32]) -> Vec { - let norm = (v.par_iter().map(|val| val * val).sum::()).sqrt(); - let epsilon = 1e-12; - - // We add the super-small epsilon to avoid dividing by zero - v.par_iter().map(|&val| val / (norm + epsilon)).collect() -} - diff --git a/ahnlich/ai/src/engine/ai/providers/processors/center_crop.rs b/ahnlich/ai/src/engine/ai/providers/processors/center_crop.rs index 968787a2..4ffc6a15 100644 --- a/ahnlich/ai/src/engine/ai/providers/processors/center_crop.rs +++ b/ahnlich/ai/src/engine/ai/providers/processors/center_crop.rs @@ -1,6 +1,6 @@ use rayon::iter::{IntoParallelRefIterator, ParallelIterator}; use crate::engine::ai::models::ImageArray; -use crate::engine::ai::providers::processors::{CONV_NEXT_FEATURE_EXTRACTOR_CENTER_CROP_THRESHOLD, Processor, ProcessorData}; +use crate::engine::ai::providers::processors::{CONV_NEXT_FEATURE_EXTRACTOR_CENTER_CROP_THRESHOLD, Preprocessor, PreprocessorData}; use crate::error::AIProxyError; pub struct CenterCrop { @@ -81,10 +81,10 @@ impl TryFrom<&serde_json::Value> for CenterCrop { } } -impl Processor for CenterCrop { - fn process(&self, data: ProcessorData) -> Result { +impl Preprocessor for CenterCrop { + fn process(&self, data: PreprocessorData) -> Result { match data { - ProcessorData::ImageArray(image_array) => { + PreprocessorData::ImageArray(image_array) => { let processed = image_array.par_iter().map(|image| { if !self.process { return Ok(image.clone()); @@ -112,7 +112,7 @@ impl Processor for CenterCrop { } }) .collect::, AIProxyError>>(); - Ok(ProcessorData::ImageArray(processed?)) + Ok(PreprocessorData::ImageArray(processed?)) }, _ => Err(AIProxyError::CenterCropError { message: "CenterCrop process failed. 
Expected ImageArray, got NdArray3C".to_string(), diff --git a/ahnlich/ai/src/engine/ai/providers/processors/imagearray_to_ndarray.rs b/ahnlich/ai/src/engine/ai/providers/processors/imagearray_to_ndarray.rs index fa53b1bb..4ee46795 100644 --- a/ahnlich/ai/src/engine/ai/providers/processors/imagearray_to_ndarray.rs +++ b/ahnlich/ai/src/engine/ai/providers/processors/imagearray_to_ndarray.rs @@ -1,14 +1,14 @@ use ndarray::{ArrayView, Ix3}; use rayon::iter::{IntoParallelRefMutIterator, ParallelIterator}; -use crate::engine::ai::providers::processors::{Processor, ProcessorData}; +use crate::engine::ai::providers::processors::{Preprocessor, PreprocessorData}; use crate::error::AIProxyError; pub struct ImageArrayToNdArray; -impl Processor for ImageArrayToNdArray { - fn process(&self, data: ProcessorData) -> Result { +impl Preprocessor for ImageArrayToNdArray { + fn process(&self, data: PreprocessorData) -> Result { match data { - ProcessorData::ImageArray(mut arrays) => { + PreprocessorData::ImageArray(mut arrays) => { let array_views: Vec> = arrays .par_iter_mut() .map(|image_arr| { @@ -19,7 +19,7 @@ impl Processor for ImageArrayToNdArray { let pixel_values_array = ndarray::stack(ndarray::Axis(0), &array_views) .map_err(|e| AIProxyError::EmbeddingShapeError(e.to_string()))?; - Ok(ProcessorData::NdArray3C(pixel_values_array)) + Ok(PreprocessorData::NdArray3C(pixel_values_array)) } _ => Err(AIProxyError::ImageArrayToNdArrayError { message: "ImageArrayToNdArray failed. Expected ImageArray, got NdArray3C".to_string(), diff --git a/ahnlich/ai/src/engine/ai/providers/processors/mod.rs b/ahnlich/ai/src/engine/ai/providers/processors/mod.rs index 4f7518f6..cb078275 100644 --- a/ahnlich/ai/src/engine/ai/providers/processors/mod.rs +++ b/ahnlich/ai/src/engine/ai/providers/processors/mod.rs @@ -1,6 +1,7 @@ use crate::engine::ai::models::ImageArray; use crate::error::AIProxyError; -use ndarray::{Array, Ix4}; +use ndarray::{Array, Ix2, Ix3, Ix4}; +use tokenizers::Encoding; pub mod normalize; pub mod resize; @@ -8,14 +9,28 @@ pub mod imagearray_to_ndarray; pub mod center_crop; pub mod rescale; pub mod preprocessor; +pub mod tokenize; +pub mod postprocessor; +pub mod pooling; pub const CONV_NEXT_FEATURE_EXTRACTOR_CENTER_CROP_THRESHOLD: u32 = 384; -pub trait Processor: Send + Sync { - fn process(&self, data: ProcessorData) -> Result; +pub trait Preprocessor: Send + Sync { + fn process(&self, data: PreprocessorData) -> Result; } -pub enum ProcessorData { +pub trait Postprocessor: Send + Sync { + fn process(&self, data: PostprocessorData) -> Result; +} + +pub enum PreprocessorData { ImageArray(Vec), - NdArray3C(Array) + NdArray3C(Array), + Text(Vec), + EncodedText(Vec), +} + +pub enum PostprocessorData { + NdArray2(Array), + NdArray3(Array) } diff --git a/ahnlich/ai/src/engine/ai/providers/processors/normalize.rs b/ahnlich/ai/src/engine/ai/providers/processors/normalize.rs index adea122f..8ef87565 100644 --- a/ahnlich/ai/src/engine/ai/providers/processors/normalize.rs +++ b/ahnlich/ai/src/engine/ai/providers/processors/normalize.rs @@ -1,15 +1,15 @@ use crate::error::AIProxyError; -use crate::engine::ai::providers::processors::{Processor, ProcessorData}; -use ndarray::Array; +use crate::engine::ai::providers::processors::{Postprocessor, PostprocessorData, Preprocessor, PreprocessorData}; +use ndarray::{Array, Axis}; use std::ops::{Div, Sub}; -pub struct Normalize { +pub struct ImageNormalize { mean: Vec, std: Vec, process: bool } -impl TryFrom<&serde_json::Value> for Normalize { +impl TryFrom<&serde_json::Value> 
for ImageNormalize { type Error = AIProxyError; fn try_from(config: &serde_json::Value) -> Result { @@ -39,14 +39,14 @@ impl TryFrom<&serde_json::Value> for Normalize { } } -impl Processor for Normalize { - fn process(&self, data: ProcessorData) -> Result { +impl Preprocessor for ImageNormalize { + fn process(&self, data: PreprocessorData) -> Result { if !self.process { return Ok(data); } match data { - ProcessorData::NdArray3C(array) => { + PreprocessorData::NdArray3C(array) => { let mean = Array::from_vec(self.mean.clone()) .into_shape_with_order((3, 1, 1)) .unwrap(); @@ -62,7 +62,7 @@ impl Processor for Normalize { let array_normalized = array .sub(mean_broadcast) .div(std_broadcast); - Ok(ProcessorData::NdArray3C(array_normalized)) + Ok(PreprocessorData::NdArray3C(array_normalized)) } _ => Err(AIProxyError::ImageNormalizationError { message: format!("Image normalization failed due to invalid shape for image array; \ @@ -75,4 +75,30 @@ impl Processor for Normalize { }), } } +} + +pub struct VectorNormalize; + +impl Postprocessor for VectorNormalize { + fn process(&self, data: PostprocessorData) -> Result { + match data { + PostprocessorData::NdArray2(array) => { + let norm = (&array * &array).sum_axis(Axis(1)).sqrt(); + let epsilon = 1e-12; + let regularized_norm = norm + epsilon; + let regularized_norm = regularized_norm.insert_axis(Axis(1)); + let source_shape = regularized_norm.shape(); + let target_shape = array.shape(); + let broadcasted_norm = regularized_norm + .broadcast(array.dim()).ok_or(AIProxyError::VectorNormalizationError { + message: format!("Could not broadcast attention mask with shape {:?} to \ + shape {:?} of the input tensor.", source_shape, target_shape), + })?.to_owned(); + Ok(PostprocessorData::NdArray2(array / broadcasted_norm)) + } + _ => Err(AIProxyError::VectorNormalizationError { + message: "Expected NdArray2, got NdArray3".to_string(), + }), + } + } } \ No newline at end of file diff --git a/ahnlich/ai/src/engine/ai/providers/processors/pooling.rs b/ahnlich/ai/src/engine/ai/providers/processors/pooling.rs new file mode 100644 index 00000000..9484b7b2 --- /dev/null +++ b/ahnlich/ai/src/engine/ai/providers/processors/pooling.rs @@ -0,0 +1,71 @@ +use ndarray::{s, Array, Array2, ArrayView, Dim, Axis, Dimension, IxDynImpl, Ix2}; +use crate::engine::ai::providers::processors::{Postprocessor, PostprocessorData}; +use crate::error::AIProxyError; + +pub enum Pooling { + Regular(RegularPooling), + Mean(MeanPooling), +} + +pub struct RegularPooling; + +impl Postprocessor for RegularPooling { + fn process(&self, data: PostprocessorData) -> Result { + match data { + PostprocessorData::NdArray3(array) => { + let processed = array.slice(s![.., 0, ..]).to_owned(); + Ok(PostprocessorData::NdArray2(processed)) + } + _ => Err(AIProxyError::PoolingError { + message: "Expected NdArray3, got NdArray2".to_string(), + }), + } + } +} + +pub struct MeanPooling { + attention_mask: Option> +} + +impl MeanPooling { + pub fn new() -> Self { + Self { attention_mask: None } + } + + pub fn set_attention_mask(&mut self, attention_mask: Option>) { + self.attention_mask = attention_mask; + } +} + +impl Postprocessor for MeanPooling { + fn process(&self, data: PostprocessorData) -> Result { + match data { + PostprocessorData::NdArray3(array) => { + let attention_mask = match &self.attention_mask { + Some(mask) => { + let attention_mask = mask.mapv(|x| x as f32); + attention_mask + .insert_axis(Axis(2)) + .broadcast(array.dim()).ok_or( + AIProxyError::PoolingError { + message: format!("Could not 
broadcast attention mask with shape {:?} to \ + shape {:?} of the input tensor.", mask.shape(), array.shape()), + } + )?.to_owned() + }, + None => Array::ones(array.dim()), + }; + + let masked_array = &attention_mask * &array; + let masked_array_sum = masked_array.sum_axis(Axis(1)); + let attention_mask_sum = attention_mask.sum_axis(Axis(1)); + let min_value = 1e-9; + let attention_mask_sum = attention_mask_sum.mapv(|x| x.max(min_value)); + Ok(PostprocessorData::NdArray2(&masked_array_sum / &attention_mask_sum)) + } + _ => Err(AIProxyError::PoolingError { + message: "Expected NdArray3, got NdArray2".to_string(), + }), + } + } +} diff --git a/ahnlich/ai/src/engine/ai/providers/processors/postprocessor.rs b/ahnlich/ai/src/engine/ai/providers/processors/postprocessor.rs new file mode 100644 index 00000000..2bf2d504 --- /dev/null +++ b/ahnlich/ai/src/engine/ai/providers/processors/postprocessor.rs @@ -0,0 +1,81 @@ +use std::sync::{Arc, Mutex}; +use ndarray::{Array, Ix2, Ix3}; +use crate::cli::server::SupportedModels; +use crate::engine::ai::providers::processors::normalize::VectorNormalize; +use crate::engine::ai::providers::processors::pooling::{Pooling, RegularPooling, MeanPooling}; +use crate::engine::ai::providers::processors::{PostprocessorData, Postprocessor}; +use crate::error::AIProxyError; + +pub enum ORTPostprocessor { + Image(ORTImagePostprocessor), + Text(ORTTextPostprocessor), +} + +pub struct ORTTextPostprocessor { + model: SupportedModels, + pooling: Arc>, + normalize: Option +} + + +impl ORTTextPostprocessor { + pub fn load(supported_model: SupportedModels) -> Result { + let artifacts = match supported_model { + SupportedModels::AllMiniLML6V2 | + SupportedModels::AllMiniLML12V2 => Ok(( + Pooling::Mean(MeanPooling::new()), + Some(VectorNormalize) + )), + SupportedModels::BGEBaseEnV15 | + SupportedModels::BGELargeEnV15 => Ok(( + Pooling::Regular(RegularPooling), + Some(VectorNormalize) + )), + SupportedModels::ClipVitB32Text => Ok(( + Pooling::Mean(MeanPooling::new()), + None + )), + _ => Err(AIProxyError::ModelPostprocessingError { + model_name: supported_model.to_string(), + message: "Unsupported model for ORTTextPostprocessor".to_string(), + }) + }?; + Ok(Self { + model: supported_model, + pooling: Arc::new(Mutex::new(artifacts.0)), + normalize: artifacts.1 + }) + } + + pub fn process(&self, embeddings: Array, attention_mask: Array) -> Result, AIProxyError> { + let mut pooling = self.pooling.lock().map_err(|_| AIProxyError::ModelPostprocessingError { + model_name: self.model.to_string(), + message: "Failed to acquire lock on pooling.".to_string() + })?; + let pooled = match &mut *pooling { + Pooling::Regular(pooling) => { + pooling.process(PostprocessorData::NdArray3(embeddings))? + }, + Pooling::Mean(pooling) => { + pooling.set_attention_mask(Some(attention_mask)); + pooling.process(PostprocessorData::NdArray3(embeddings))? 
+ } + }; + let result = match &self.normalize { + Some(normalize) => normalize.process(pooled), + None => Ok(pooled) + }?; + match result { + PostprocessorData::NdArray2(array) => Ok(array), + _ => Err(AIProxyError::ModelPostprocessingError { + model_name: self.model.to_string(), + message: "Expected NdArray2, got NdArray3".to_string() + }) + } + } +} + +struct ORTImagePostprocessor { + model: SupportedModels, + normalize: VectorNormalize +} \ No newline at end of file diff --git a/ahnlich/ai/src/engine/ai/providers/processors/preprocessor.rs b/ahnlich/ai/src/engine/ai/providers/processors/preprocessor.rs index da3d83b1..a78db697 100644 --- a/ahnlich/ai/src/engine/ai/providers/processors/preprocessor.rs +++ b/ahnlich/ai/src/engine/ai/providers/processors/preprocessor.rs @@ -1,22 +1,25 @@ use std::iter; +use std::sync::{Arc, Mutex}; use hf_hub::api::sync::ApiRepo; use ndarray::{Array, Ix4}; -use tokenizers::Tokenizer; +use tokenizers::{Encoding, Tokenizer}; use crate::engine::ai::models::ImageArray; use crate::engine::ai::providers::ort_helper::HFConfigReader; use crate::engine::ai::providers::processors::center_crop::CenterCrop; use crate::engine::ai::providers::processors::imagearray_to_ndarray::ImageArrayToNdArray; -use crate::engine::ai::providers::processors::normalize::Normalize; -use crate::engine::ai::providers::processors::{Processor, ProcessorData}; +use crate::engine::ai::providers::processors::normalize::ImageNormalize; +use crate::engine::ai::providers::processors::{Preprocessor, PreprocessorData}; use crate::engine::ai::providers::processors::rescale::Rescale; use crate::engine::ai::providers::processors::resize::Resize; +use crate::engine::ai::providers::processors::tokenize::Tokenize; use crate::error::AIProxyError; +#[derive(Clone)] pub struct ImagePreprocessorFiles { - resize: Option, - normalize: Option, - rescale: Option, - center_crop: Option, + pub resize: Option, + pub normalize: Option, + pub rescale: Option, + pub center_crop: Option, } impl ImagePreprocessorFiles { @@ -44,18 +47,47 @@ impl Default for ImagePreprocessorFiles { } } +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct TokenizerFiles { + pub tokenizer_file: String, + pub config_file: String, + pub special_tokens_map_file: String, + pub tokenizer_config_file: String, +} + +impl Default for TokenizerFiles { + fn default() -> Self { + Self { + tokenizer_file: "tokenizer.json".to_string(), + config_file: "config.json".to_string(), + special_tokens_map_file: "special_tokens_map.json".to_string(), + tokenizer_config_file: "tokenizer_config.json".to_string(), + } + } +} + +#[derive(Default, Clone)] +pub struct TextPreprocessorFiles { + pub tokenize: TokenizerFiles, +} + +pub enum ORTPreprocessor { + Image(ORTImagePreprocessor), + Text(ORTTextPreprocessor), +} + #[derive(Default)] pub struct ORTImagePreprocessor { - imagearray_to_ndarray: Option>, - normalize: Option>, - resize: Option>, - rescale: Option>, - center_crop: Option>, + imagearray_to_ndarray: Option>, + normalize: Option>, + resize: Option>, + rescale: Option>, + center_crop: Option>, } impl ORTImagePreprocessor { pub fn iter(&self) -> impl Iterator)> { + &str, &Box)> { iter::empty() .chain(self.resize.as_ref().map( |f| ("resize", f))) @@ -87,7 +119,7 @@ impl ORTImagePreprocessor { self.resize = Some(Box::new(Resize::try_from(&config.expect("Config exists"))?)); } "normalize" => { - self.normalize = Some(Box::new(Normalize::try_from(&config.expect("Config exists"))?)); + self.normalize = Some(Box::new(ImageNormalize::try_from(&config.expect("Config 
exists"))?)); } "rescale" => { self.rescale = Some(Box::new(Rescale::try_from(&config.expect("Config exists"))?)); @@ -104,12 +136,12 @@ impl ORTImagePreprocessor { } pub fn process(&self, data: Vec) -> Result, AIProxyError> { - let mut data = ProcessorData::ImageArray(data); + let mut data = PreprocessorData::ImageArray(data); for (_, processor) in self.iter() { data = processor.process(data)?; } match data { - ProcessorData::NdArray3C(array) => Ok(array), + PreprocessorData::NdArray3C(array) => Ok(array), _ => Err(AIProxyError::ModelProviderPreprocessingError( "Expected NdArray after processing".to_string() )) @@ -117,11 +149,35 @@ impl ORTImagePreprocessor { } } -pub enum ORTPreprocessor { - Image(ORTImagePreprocessor), - Text(ORTTextPreprocessor), +pub struct ORTTextPreprocessor { + pub tokenize: Arc> } -pub struct ORTTextPreprocessor { - pub tokenizer: Tokenizer, +impl ORTTextPreprocessor { + pub fn load(model_repo: ApiRepo, processor_files: TextPreprocessorFiles) -> Result { + Ok( + ORTTextPreprocessor { + tokenize: Arc::new(Mutex::new( + Tokenize::initialize(processor_files.tokenize, model_repo)?, + )), + } + ) + } + + pub fn process(&self, data: Vec, truncate: bool) -> Result, AIProxyError> { + let mut data = PreprocessorData::Text(data); + let mut tokenize = self.tokenize.lock().map_err(|_| { + AIProxyError::ModelProviderPreprocessingError( + "Failed to acquire lock on tokenizer".to_string(), + ) + })?; + tokenize.set_truncate(truncate); + data = tokenize.process(data)?; + match data { + PreprocessorData::EncodedText(encodings) => Ok(encodings), + _ => Err(AIProxyError::ModelProviderPreprocessingError( + "Expected EncodedText after processing".to_string() + )) + } + } } \ No newline at end of file diff --git a/ahnlich/ai/src/engine/ai/providers/processors/rescale.rs b/ahnlich/ai/src/engine/ai/providers/processors/rescale.rs index 5b3ed2fb..48e74bf6 100644 --- a/ahnlich/ai/src/engine/ai/providers/processors/rescale.rs +++ b/ahnlich/ai/src/engine/ai/providers/processors/rescale.rs @@ -1,4 +1,4 @@ -use crate::engine::ai::providers::processors::{Processor, ProcessorData}; +use crate::engine::ai::providers::processors::{Preprocessor, PreprocessorData}; use crate::error::AIProxyError; pub struct Rescale { @@ -25,17 +25,17 @@ impl TryFrom<&serde_json::Value> for Rescale { } } -impl Processor for Rescale { - fn process(&self, data: ProcessorData) -> Result { +impl Preprocessor for Rescale { + fn process(&self, data: PreprocessorData) -> Result { if !self.process { return Ok(data); } match data { - ProcessorData::NdArray3C(array) => { + PreprocessorData::NdArray3C(array) => { let mut array = array; array *= self.scale; - Ok(ProcessorData::NdArray3C(array)) + Ok(PreprocessorData::NdArray3C(array)) }, _ => Err(AIProxyError::RescaleError { message: "Rescale process failed. 
Expected NdArray3C, got ImageArray".to_string(), diff --git a/ahnlich/ai/src/engine/ai/providers/processors/resize.rs b/ahnlich/ai/src/engine/ai/providers/processors/resize.rs index a670db38..e46dbb53 100644 --- a/ahnlich/ai/src/engine/ai/providers/processors/resize.rs +++ b/ahnlich/ai/src/engine/ai/providers/processors/resize.rs @@ -1,7 +1,7 @@ use image::imageops::FilterType; use rayon::iter::{IntoParallelRefMutIterator, ParallelIterator}; use crate::engine::ai::models::ImageArray; -use crate::engine::ai::providers::processors::{CONV_NEXT_FEATURE_EXTRACTOR_CENTER_CROP_THRESHOLD, Processor, ProcessorData}; +use crate::engine::ai::providers::processors::{CONV_NEXT_FEATURE_EXTRACTOR_CENTER_CROP_THRESHOLD, Preprocessor, PreprocessorData}; use crate::error::AIProxyError; pub struct Resize { @@ -84,10 +84,10 @@ impl TryFrom<&serde_json::Value> for Resize { } } -impl Processor for Resize { - fn process(&self, data: ProcessorData) -> Result { +impl Preprocessor for Resize { + fn process(&self, data: PreprocessorData) -> Result { match data { - ProcessorData::ImageArray(mut arrays) => { + PreprocessorData::ImageArray(mut arrays) => { let processed = arrays.par_iter_mut() .map(|image| { if !self.process { @@ -98,7 +98,7 @@ impl Processor for Resize { Ok(image) }) .collect::, AIProxyError>>(); - Ok(ProcessorData::ImageArray(processed?)) + Ok(PreprocessorData::ImageArray(processed?)) } _ => Err(AIProxyError::ImageArrayToNdArrayError { message: "Resize failed. Expected ImageArray, got NdArray3C".to_string(), diff --git a/ahnlich/ai/src/engine/ai/providers/processors/tokenize.rs b/ahnlich/ai/src/engine/ai/providers/processors/tokenize.rs new file mode 100644 index 00000000..0095e86e --- /dev/null +++ b/ahnlich/ai/src/engine/ai/providers/processors/tokenize.rs @@ -0,0 +1,144 @@ +use hf_hub::api::sync::ApiRepo; +use serde_json::Value; +use tokenizers::{AddedToken, PaddingParams, PaddingStrategy, Tokenizer, TruncationParams}; +use tokenizers::decoders::bpe::BPEDecoder; +use crate::engine::ai::providers::ort_helper::{HFConfigReader, read_file_to_bytes}; +use crate::engine::ai::providers::processors::preprocessor::TokenizerFiles; +use crate::engine::ai::providers::processors::{Preprocessor, PreprocessorData}; +use crate::error::AIProxyError; + + +pub struct Tokenize { + tokenizer: Tokenizer, + model_max_length: usize, + truncate: bool +} + +pub struct TokenizeArtifacts { + pub tokenizer_bytes: Vec, + pub config: Value, + pub special_tokens_map: Value, + pub tokenizer_config: Value, +} + +impl Tokenize { + pub fn download_artifacts(tokenizer_files: TokenizerFiles, model_repo: ApiRepo) -> Result { + let tokenizer_bytes = read_file_to_bytes( + &model_repo.get(&tokenizer_files.tokenizer_file) + .map_err(|e| AIProxyError::ModelConfigLoadError{ + message: format!("failed to fetch {}, {}", &tokenizer_files.tokenizer_file, e.to_string()), + })?)?; + let mut config_reader = HFConfigReader::new(model_repo); + let config = config_reader.read(&tokenizer_files.config_file)?; + let special_tokens_map = config_reader.read(&tokenizer_files.special_tokens_map_file)?; + let tokenizer_config = config_reader.read(&tokenizer_files.tokenizer_config_file)?; + Ok(TokenizeArtifacts { + tokenizer_bytes, + config, + special_tokens_map, + tokenizer_config + }) + } + + pub fn initialize(tokenizer_files: TokenizerFiles, model_repo: ApiRepo) -> Result { + let artifacts = Self::download_artifacts(tokenizer_files, model_repo)?; + let mut tokenizer = + Tokenizer::from_bytes(artifacts.tokenizer_bytes).map_err(|_| 
AIProxyError::ModelTokenizerLoadError { + message: "Error building Tokenizer from bytes.".to_string(), + })?; + + //For BGEBaseSmall, the model_max_length value is set to 1000000000000000019884624838656. Which fits in a f64 + let model_max_length = artifacts.tokenizer_config["model_max_length"] + .as_f64() + .ok_or(AIProxyError::ModelTokenizerLoadError { + message: "Error reading model_max_length from tokenizer_config".to_string(), + })? as usize; + let pad_id = artifacts.config["pad_token_id"].as_u64().unwrap_or(0) as u32; + let pad_token = artifacts.tokenizer_config["pad_token"] + .as_str() + .ok_or(AIProxyError::ModelTokenizerLoadError { + message: "Error reading pad_token from tokenizer_config".to_string(), + })? + .into(); + + let mut tokenizer = tokenizer + .with_padding(Some(PaddingParams { + // TODO: the user should able to choose the padding strategy + strategy: PaddingStrategy::BatchLongest, + pad_token, + pad_id, + ..Default::default() + })) + .with_truncation(Some(TruncationParams { + max_length: model_max_length, + ..Default::default() + })) + .map_err(|_| AIProxyError::ModelTokenizerLoadError { + message: "Error setting padding and truncation params.".to_string(), + })? + .clone(); + if let serde_json::Value::Object(root_object) = artifacts.special_tokens_map { + for (_, value) in root_object.iter() { + if value.is_string() { + tokenizer.add_special_tokens(&[AddedToken { + content: value.as_str().unwrap().into(), + special: true, + ..Default::default() + }]); + } else if value.is_object() { + tokenizer.add_special_tokens(&[AddedToken { + content: value["content"].as_str().unwrap().into(), + special: true, + single_word: value["single_word"].as_bool().unwrap(), + lstrip: value["lstrip"].as_bool().unwrap(), + rstrip: value["rstrip"].as_bool().unwrap(), + normalized: value["normalized"].as_bool().unwrap(), + }]); + } + } + } + + let decoder = BPEDecoder::new("".to_string()); + tokenizer.with_decoder(Some(decoder)); + + Ok(Self { tokenizer: tokenizer.into(), model_max_length, truncate: true }) + } + + pub fn set_truncate(&mut self, truncate: bool) -> Result<(), AIProxyError> { + let tokenizer; + if truncate { + tokenizer = self.tokenizer.with_truncation( + Some(TruncationParams { + max_length: self.model_max_length, + ..Default::default() + })).map_err(|_| AIProxyError::ModelTokenizerLoadError { + message: "Error setting truncation params.".to_string(), + })?; + } else { + tokenizer = self.tokenizer.with_truncation(None) + .map_err(|_| AIProxyError::ModelTokenizerLoadError { + message: "Error removing truncation params.".to_string(), + })?; + } + self.truncate = truncate; + self.tokenizer = tokenizer.clone().into(); + Ok(()) + } +} + +impl Preprocessor for Tokenize { + fn process(&self, data: PreprocessorData) -> Result { + match data { + PreprocessorData::Text(text) => { + let tokenized = self.tokenizer.encode_batch(text.clone(), true) + .map_err(|_| AIProxyError::ModelTokenizationError { + message: format!("Tokenize process failed. Texts: {:?}", text), + })?; + Ok(PreprocessorData::EncodedText(tokenized)) + }, + _ => Err(AIProxyError::ModelTokenizationError { + message: format!("Tokenize process failed. Expected Text."), + }), + } + } +} diff --git a/ahnlich/ai/src/error.rs b/ahnlich/ai/src/error.rs index 7b77b215..3c78234b 100644 --- a/ahnlich/ai/src/error.rs +++ b/ahnlich/ai/src/error.rs @@ -32,14 +32,28 @@ pub enum AIProxyError { }, #[error("Shape error {0}")] EmbeddingShapeError(String), - #[error("Max Token Exceeded. 
Model Expects [{max_token_size}], input type was [{input_token_size}] ")] + #[error("Max Token Exceeded. Model Expects [{max_token_size}], input type was [{input_token_size}].")] TokenExceededError { max_token_size: usize, input_token_size: usize, }, - #[error("Model does not support token truncation.")] - TokenTruncationNotSupported, + #[error("Model preprocessing for {model_name} failed: {message}.")] + ModelPreprocessingError { + model_name: String, + message: String + }, + + #[error("Model postprocessing for {model_name} failed: {message}.")] + ModelPostprocessingError { + model_name: String, + message: String + }, + + #[error("Pooling operation failed: {message}.")] + PoolingError { + message: String + }, #[error( "Image Dimensions [({0}, {1})] does not match the expected model dimensions [({2}, {3})]", @@ -51,7 +65,7 @@ pub enum AIProxyError { }, #[error("Error initializing text embedding")] TextEmbeddingInitError(String), - #[error("API Builder Error {0}")] + #[error("API Builder Error: {0}")] APIBuilderError(String), #[error("Tokenizer initialization error {0}")] TokenizerInitError(String), @@ -80,7 +94,12 @@ pub enum AIProxyError { model_name: String }, - #[error("Normalization error: [{message}]")] + #[error("Vector normalization error: [{message}]")] + VectorNormalizationError { + message: String + }, + + #[error("Image normalization error: [{message}]")] ImageNormalizationError { message: String }, @@ -138,23 +157,24 @@ pub enum AIProxyError { #[error("Model provider failed on preprocessing the input {0}")] ModelProviderPreprocessingError(String), - #[error("Inference can only be run on one of text or image inputs, not both.")] - VaryingInferenceInputTypes, - #[error("Model provider failed on running inference {0}")] ModelProviderRunInferenceError(String), #[error("Model provider failed on postprocessing the output {0}")] ModelProviderPostprocessingError(String), - #[error("Model provider failed on tokenization of text inputs.")] - ModelTokenizationError, + #[error("Tokenize error: {message}")] + ModelTokenizationError { + message: String + }, #[error("Cannot call DelKey on store with `store_original` as false")] DelKeyError, - #[error("Tokenizer for model failed on loading.")] - ModelTokenizerLoadError, + #[error("Tokenizer for model failed to load: {message}")] + ModelTokenizerLoadError { + message: String + }, #[error("Unable to load config: [{message}].")] ModelConfigLoadError{ diff --git a/ahnlich/ai/src/manager/mod.rs b/ahnlich/ai/src/manager/mod.rs index 059c0a31..b445ec12 100644 --- a/ahnlich/ai/src/manager/mod.rs +++ b/ahnlich/ai/src/manager/mod.rs @@ -7,13 +7,15 @@ use crate::engine::ai::models::{ImageArray, InputAction}; /// lets AIProxyTasks communicate with any model to receive immediate responses via a oneshot /// channel use crate::engine::ai::models::{Model, ModelInput}; -use crate::engine::ai::providers::{ModelProviders, TextPreprocessorTrait}; +use crate::engine::ai::providers::ModelProviders; +use crate::engine::ai::providers::ort::ORTProvider; use crate::error::AIProxyError; -use ahnlich_types::ai::{AIModel, AIStoreInputType, ImageAction, PreprocessAction, StringAction}; +use ahnlich_types::ai::{AIModel, AIStoreInputType, PreprocessAction}; use ahnlich_types::keyval::{StoreInput, StoreKey}; use fallible_collections::FallibleVec; use moka::future::Cache; use rayon::prelude::*; +use tokenizers::Encoding; use task_manager::Task; use task_manager::TaskManager; use task_manager::TaskState; @@ -77,109 +79,111 @@ impl ModelThread { &self, process_action: 
PreprocessAction,
inputs: Vec<StoreInput>,
- ) -> Result<Vec<ModelInput>, AIProxyError> {
- let preprocessed_inputs = inputs
- .into_par_iter()
- .try_fold(Vec::new, |mut accumulator, input| {
- let model_input = ModelInput::try_from(input)?;
- let processed_input = match (process_action, model_input) {
- (PreprocessAction::Image(image_action), ModelInput::Image(image_array)) => {
- let output = self.process_image(image_array, image_action)?;
- Ok(ModelInput::Image(output))
- }
- (PreprocessAction::RawString(string_action), ModelInput::Text(string)) => {
- let output = self.preprocess_raw_string(string, string_action)?;
- Ok(ModelInput::Text(output))
- }
- (_, model_input) => Err(AIProxyError::PreprocessingMismatchError {
- input_type: (&model_input).into(),
- preprocess_action: process_action,
- }),
- }?;
- accumulator.push(processed_input);
- Ok::<Vec<ModelInput>, AIProxyError>(accumulator)
- })
- .try_reduce(Vec::new, |mut accumulator, mut item| {
- accumulator.append(&mut item);
- Ok(accumulator)
- })?;
- Ok(preprocessed_inputs)
+ ) -> Result<ModelInput, AIProxyError> {
+ let sample = inputs.first().ok_or(AIProxyError::ModelPreprocessingError {
+ model_name: self.model.model_name(),
+ message: "Input is empty".to_string(),
+ })?;
+ match sample {
+ StoreInput::RawString(_) => {
+ let inputs_inner: Vec<String> = inputs.into_par_iter().filter_map(|input| match input {
+ StoreInput::RawString(string) => Some(string),
+ _ => None,
+ }).collect();
+ let output = self.preprocess_raw_string(inputs_inner, process_action)?;
+ Ok(ModelInput::Texts(output))
+ }
+ StoreInput::Image(_) => {
+ // let inputs_inner: Vec<Vec<u8>> = inputs.into_par_iter().filter_map(|input| match input {
+ // StoreInput::Image(image_bytes) => Some(image_bytes),
+ // _ => None,
+ // }).collect();
+ // let output = self.preprocess_image(inputs_inner, process_action)?;
+ Ok(ModelInput::Images(vec![]))
+ }
+ }
}

- #[tracing::instrument(skip(self, input))]
+ #[tracing::instrument(skip(self, inputs))]
fn preprocess_raw_string(
&self,
- input: String,
- string_action: StringAction,
- ) -> Result<String, AIProxyError> {
- let max_token_size = self.model.max_input_token().ok_or_else(|| {
- AIProxyError::AIModelInvalidOperation {
- model_name: self.model.model_name(),
- operation: "[max_input_token] function".to_string()
- }
- })?;
-
+ inputs: Vec<String>,
+ process_action: PreprocessAction,
+ ) -> Result<Vec<Encoding>, AIProxyError> {
if self.model.input_type() != AIStoreInputType::RawString {
- return Err(AIProxyError::TokenTruncationNotSupported);
+ return Err(AIProxyError::ModelPreprocessingError {
+ model_name: self.model.model_name(),
+ message: "Model does not accept RawString inputs.".to_string(),
+ });
}
- let process = |provider: &dyn TextPreprocessorTrait, input| {
- let mut tokens = provider.encode_str(input)?;
- let max_token_size: usize = max_token_size.into();
- if (tokens.len() > max_token_size) &&
- (string_action == StringAction::ErrorIfTokensExceed) {
- Err(AIProxyError::TokenExceededError {
- input_token_size: tokens.len(),
- max_token_size,
- })
- } else {
- tokens.truncate(max_token_size);
- let processed_input = provider.decode_tokens(tokens)?;
- Ok(processed_input)
+ let max_token_size = usize::from(self.model.max_input_token().ok_or_else(|| {
+ AIProxyError::AIModelInvalidOperation {
+ model_name: self.model.model_name(),
+ operation: "[max_input_token] function".to_string()
}
- };
+ })?);
match &self.model.provider {
- ModelProviders::FastEmbed(provider) => {
- process(provider, &input)
- },
ModelProviders::ORT(provider) => {
+ let truncate = match process_action {
PreprocessAction::ModelPreprocessing => true, + _ => false + }; + let outputs = provider.preprocess_texts(inputs, truncate)?; + let token_size = outputs.first().ok_or(AIProxyError::ModelPreprocessingError { + model_name: self.model.model_name(), + message: "Processed output is empty".to_string(), + })?.len(); + if token_size > max_token_size { + return Err(AIProxyError::TokenExceededError { + max_token_size, + input_token_size: token_size, + }); + } else { + return Ok(outputs); + } } } } #[tracing::instrument(skip(self, input))] - fn process_image( + fn preprocess_image( &self, input: ImageArray, - image_action: ImageAction, + process_action: PreprocessAction + // input: Vec>, ) -> Result { + Err(AIProxyError::ModelPreprocessingError { + model_name: self.model.model_name(), + message: "Image preprocessing is not supported.".to_string(), + }) // process image, return error if max dimensions exceeded - let dimensions = input.image_dim(); - - let preprocess_mismatch = Err(AIProxyError::PreprocessingMismatchError { - input_type: AIStoreInputType::Image, - preprocess_action: PreprocessAction::Image(image_action), - }); - - let Some((expected_width, expected_height)) = self.model.expected_image_dimensions() else { - return preprocess_mismatch; - }; - - let (width, height) = dimensions; - - if width != expected_width || height != expected_height { - if let ImageAction::ErrorIfDimensionsMismatch = image_action { - return Err(AIProxyError::ImageDimensionsMismatchError { - image_dimensions: (width.into(), height.into()), - expected_dimensions: (expected_width.into(), expected_height.into()), - }); - } else { - return Ok(input); - } - } - - Ok(input) + // let dimensions = input.image_dim(); + // + // let Some((expected_width, expected_height)) = + // self.model.expected_image_dimensions() + // .ok_or( + // Err(AIProxyError::PreprocessingMismatchError { + // input_type: AIStoreInputType::Image, + // preprocess_action: process_action, + // }))?; + // + // match process_action { + // PreprocessAction::NoPreprocessing => { + // Ok(input) + // } + // PreprocessAction::ModelPreprocessing => { + // let (width, height) = dimensions; + // if width != expected_width || height != expected_height { + // Err(AIProxyError::ImageDimensionsMismatchError { + // image_dimensions: (width.into(), height.into()), + // expected_dimensions: (expected_width.into(), expected_height.into()), + // }) + // } else { + // Ok(input) + // } + // } + // } } } @@ -348,7 +352,7 @@ mod tests { let evicted_model = model_manager.models.get(&sample_supported_model).await; let inputs = vec![StoreInput::RawString(String::from("Hello"))]; - let action = PreprocessAction::RawString(StringAction::TruncateIfTokensExceed); + let action = PreprocessAction::ModelPreprocessing; let _ = model_manager .handle_request(&sample_ai_model, inputs, action, InputAction::Query) .await diff --git a/ahnlich/ai/src/server/task.rs b/ahnlich/ai/src/server/task.rs index b578dc9a..8810bf41 100644 --- a/ahnlich/ai/src/server/task.rs +++ b/ahnlich/ai/src/server/task.rs @@ -1,8 +1,7 @@ use crate::engine::ai::models::Model; use ahnlich_client_rs::db::DbClient; use ahnlich_types::ai::{ - AIQuery, AIServerQuery, AIServerResponse, AIServerResult, ImageAction, PreprocessAction, - StringAction, + AIQuery, AIServerQuery, AIServerResponse, AIServerResult, PreprocessAction }; use ahnlich_types::client::ConnectedClient; use ahnlich_types::db::{ServerInfo, ServerResponse}; @@ -324,10 +323,8 @@ impl AhnlichProtocol for AIProxyTask { // TODO: Replace this with calls to 
self.model_manager.handle_request // TODO (HAKSOAT): Shouldn't preprocess action also be in the params? let preprocess = match search_input { - StoreInput::RawString(_) => { - PreprocessAction::RawString(StringAction::TruncateIfTokensExceed) - } - StoreInput::Image(_) => PreprocessAction::Image(ImageAction::ResizeImage), + StoreInput::RawString(_) => PreprocessAction::ModelPreprocessing, + StoreInput::Image(_) => PreprocessAction::ModelPreprocessing, }; let repr = self .store_handler diff --git a/ahnlich/ai/src/tests/aiproxy_test.rs b/ahnlich/ai/src/tests/aiproxy_test.rs index ccedba81..7f66d8a9 100644 --- a/ahnlich/ai/src/tests/aiproxy_test.rs +++ b/ahnlich/ai/src/tests/aiproxy_test.rs @@ -3,7 +3,7 @@ use ahnlich_db::server::handler::Server; use ahnlich_types::{ ai::{ AIModel, AIQuery, AIServerQuery, AIServerResponse, AIServerResult, AIStoreInfo, - ImageAction, PreprocessAction, StringAction, + PreprocessAction }, db::StoreUpsert, keyval::{StoreInput, StoreName, StoreValue}, @@ -208,7 +208,7 @@ async fn test_ai_store_no_original() { AIQuery::Set { store: store_name.clone(), inputs: store_data.clone(), - preprocess_action: PreprocessAction::RawString(StringAction::ErrorIfTokensExceed), + preprocess_action: PreprocessAction::NoPreprocessing, }, ]); let mut reader = BufReader::new(first_stream); @@ -282,7 +282,7 @@ async fn test_ai_proxy_get_pred_succeeds() { AIQuery::Set { store: store_name.clone(), inputs: store_data.clone(), - preprocess_action: PreprocessAction::RawString(StringAction::ErrorIfTokensExceed), + preprocess_action: PreprocessAction::NoPreprocessing, }, ]); let mut reader = BufReader::new(first_stream); @@ -362,7 +362,7 @@ async fn test_ai_proxy_get_sim_n_succeeds() { AIQuery::Set { store: store_name.clone(), inputs: store_data.clone(), - preprocess_action: PreprocessAction::RawString(StringAction::ErrorIfTokensExceed), + preprocess_action: PreprocessAction::NoPreprocessing, }, ]); let mut reader = BufReader::new(first_stream); @@ -429,7 +429,7 @@ async fn test_ai_proxy_create_drop_pred_index() { AIQuery::Set { store: store_name.clone(), inputs: store_data.clone(), - preprocess_action: PreprocessAction::RawString(StringAction::ErrorIfTokensExceed), + preprocess_action: PreprocessAction::NoPreprocessing, }, AIQuery::GetPred { store: store_name.clone(), @@ -500,7 +500,7 @@ async fn test_ai_proxy_del_key_drop_store() { AIQuery::Set { store: store_name.clone(), inputs: store_data.clone(), - preprocess_action: PreprocessAction::RawString(StringAction::ErrorIfTokensExceed), + preprocess_action: PreprocessAction::NoPreprocessing, }, AIQuery::DelKey { store: store_name.clone(), @@ -791,13 +791,13 @@ async fn test_ai_proxy_binary_store_actions() { AIQuery::Set { store: store_name.clone(), inputs: store_data, - preprocess_action: PreprocessAction::Image(ImageAction::ErrorIfDimensionsMismatch), + preprocess_action: PreprocessAction::NoPreprocessing, }, // all dimensions match 224x224 so no error AIQuery::Set { store: store_name.clone(), inputs: oversize_data, - preprocess_action: PreprocessAction::Image(ImageAction::ErrorIfDimensionsMismatch), + preprocess_action: PreprocessAction::NoPreprocessing, }, // expect an error as the dimensions do not match 224x224 AIQuery::DropPredIndex { @@ -889,7 +889,7 @@ async fn test_ai_proxy_binary_store_set_text_and_binary_fails() { AIQuery::Set { store: store_name.clone(), inputs: store_data, - preprocess_action: PreprocessAction::RawString(StringAction::ErrorIfTokensExceed), + preprocess_action: PreprocessAction::NoPreprocessing, }, 
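// Editor's note (illustrative sketch, editor-added): after this change every
// call site chooses between exactly two preprocessing modes. A minimal Set
// query under the new enum, assuming the re-exports from ahnlich_types used in
// these tests (`StoreName`'s tuple construction is inferred from its usage
// elsewhere), would look like:
//
//     let query = AIQuery::Set {
//         store: StoreName("my_store".to_string()),
//         inputs: vec![(StoreInput::RawString("hello".into()), HashMap::new())],
//         preprocess_action: PreprocessAction::ModelPreprocessing,
//     };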
AIQuery::PurgeStores, ]); diff --git a/ahnlich/client/src/ai.rs b/ahnlich/client/src/ai.rs index f7e2ef05..b4b68f4f 100644 --- a/ahnlich/client/src/ai.rs +++ b/ahnlich/client/src/ai.rs @@ -596,7 +596,7 @@ mod tests { HashMap::new() ), ], - PreprocessAction::RawString(StringAction::ErrorIfTokensExceed), + PreprocessAction::NoPreprocessing, None ) .await @@ -750,7 +750,7 @@ mod tests { pipeline.set( store_name.clone(), store_data, - PreprocessAction::RawString(StringAction::ErrorIfTokensExceed), + PreprocessAction::NoPreprocessing, ); pipeline.drop_pred_index( @@ -876,7 +876,7 @@ mod tests { pipeline.set( store_name.clone(), store_data, - PreprocessAction::Image(ImageAction::ErrorIfDimensionsMismatch), + PreprocessAction::NoPreprocessing, ); pipeline.drop_pred_index( diff --git a/ahnlich/client/src/lib.rs b/ahnlich/client/src/lib.rs index 177fb189..d12c31c0 100644 --- a/ahnlich/client/src/lib.rs +++ b/ahnlich/client/src/lib.rs @@ -18,7 +18,7 @@ //! let db_client = DbClient::new_with_pool(pool); //! //! // Library has support for distributed tracing. https://www.w3.org/TR/trace-context/#traceparent-header -//! let tracing_id: Option = None, +//! let tracing_id: Option = None; //! db_client.ping(tracing_id).await.unwrap(); //! ``` //! @@ -31,7 +31,7 @@ //! let manager = AIConnManager::new("127.0.0.1".into(), 1369); //! let pool = Pool::builder(manager).max_size(10).build().unwrap(); //! // Library has support for distributed tracing - https://www.w3.org/TR/trace-context/#traceparent-header -//! let tracing_id: Option = None, +//! let tracing_id: Option = None; //! let ai_client = AIClient::new_with_pool(pool); //! ai_client.ping(tracing_id).await.unwrap(); //! ``` @@ -47,7 +47,7 @@ //! use ahnlich_client_rs::db::DbClient; //! //! let db_client = DbClient::new("127.0.0.1".into(), 1369).await.unwrap(); -//! let tracing_id: Option = None, +//! let tracing_id: Option = None; //! let mut pipeline = db_client.pipeline(3, tracing_id).unwrap(); //! pipeline.info_server(); //! pipeline.list_clients(); @@ -67,7 +67,7 @@ //! use std::collections::HashSet; //! //! let db_client = DbClient::new("127.0.0.1".into(), 1369).await.unwrap(); -//! let tracing_id: Option = None, +//! let tracing_id: Option = None; //! let mut pipeline = db_client.pipeline(1, tracing_id).unwrap(); //! pipeline.create_store( //! // StoreName found in prelude @@ -91,7 +91,7 @@ //! let query_model = AIModel::AllMiniLML6V2; //! // Model used to set to create embeddings for set command //! let index_model = AIModel::AllMiniLML6V2; -//! let tracing_id: Option = None, +//! let tracing_id: Option = None; //! let mut pipeline = ai_client.pipeline(2, tracing_id).unwrap(); //! pipeline.create_store( //! store_name.clone(), @@ -108,8 +108,8 @@ //! (StoreInput::RawString("Adidas Yeezy".into()), HashMap::new()), //! (StoreInput::RawString("Nike Air Jordans".into()),HashMap::new()), //! ], -//! PreprocessAction::RawString(StringAction::ErrorIfTokensExceed) -//! ) +//! PreprocessAction::NoPreprocessing +//! ); //! let results = pipeline.exec().await.unwrap(); //! 
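//! // Editor's note (illustrative only, editor-added): responses come back in
//! // the order the pipeline queued them, so a typical follow-up walks them in
//! // sequence, e.g. (assuming `AIServerResult` exposes its inner list):
//! // for response in results.into_inner() { /* match on AIServerResponse */ }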
``` pub mod ai; diff --git a/ahnlich/dsl/src/ai.rs b/ahnlich/dsl/src/ai.rs index c15d875c..412a1c03 100644 --- a/ahnlich/dsl/src/ai.rs +++ b/ahnlich/dsl/src/ai.rs @@ -10,7 +10,7 @@ use crate::{ }, }; use ahnlich_types::{ - ai::{AIModel, AIQuery, ImageAction, PreprocessAction, StringAction}, + ai::{AIModel, AIQuery, PreprocessAction}, keyval::StoreName, metadata::MetadataKey, }; @@ -20,14 +20,8 @@ use crate::{error::DslError, predicate::parse_predicate_expression}; fn parse_to_preprocess_action(input: &str) -> PreprocessAction { match input.to_lowercase().trim() { - "erroriftokensexceed" => PreprocessAction::RawString(StringAction::ErrorIfTokensExceed), - "truncateiftokensexceed" => { - PreprocessAction::RawString(StringAction::TruncateIfTokensExceed) - } - "resizeimage" => PreprocessAction::Image(ImageAction::ResizeImage), - "errorifdimensionsmismatch" => { - PreprocessAction::Image(ImageAction::ErrorIfDimensionsMismatch) - } + "nopreprocessing" => PreprocessAction::NoPreprocessing, + "modelpreprocessing" => PreprocessAction::ModelPreprocessing, _ => panic!("Unexpected preprocess action"), } } diff --git a/ahnlich/dsl/src/syntax/syntax.pest b/ahnlich/dsl/src/syntax/syntax.pest index 863089a5..67af4580 100644 --- a/ahnlich/dsl/src/syntax/syntax.pest +++ b/ahnlich/dsl/src/syntax/syntax.pest @@ -94,13 +94,12 @@ ai_model = { ^"bge-base-en-v1.5"| ^"bge-large-en-v1.5"| ^"resnet-50"| - ^"clip-vit-b32" + ^"clip-vit-b32-image"| + ^"clip-vit-b32-text" } preprocess_action = { - ^"truncateiftokensexceed" | - ^"erroriftokensexceed" | - ^"resizeimage" | - ^"errorifdimensionsmismatch" + ^"nopreprocessing" | + ^"modelpreprocessing" } // Numbers diff --git a/ahnlich/dsl/src/tests/ai.rs b/ahnlich/dsl/src/tests/ai.rs index 24026b32..2f59f4f5 100644 --- a/ahnlich/dsl/src/tests/ai.rs +++ b/ahnlich/dsl/src/tests/ai.rs @@ -1,6 +1,6 @@ use crate::error::DslError; use ahnlich_types::{ - ai::{AIModel, AIQuery, PreprocessAction, StringAction}, + ai::{AIModel, AIQuery, PreprocessAction}, keyval::{StoreInput, StoreName}, metadata::MetadataKey, }; @@ -376,7 +376,7 @@ fn test_set_in_store_parse() { ]) ) ], - preprocess_action: PreprocessAction::RawString(StringAction::ErrorIfTokensExceed), + preprocess_action: PreprocessAction::NoPreprocessing, }] ); } diff --git a/ahnlich/types/src/ai/mod.rs b/ahnlich/types/src/ai/mod.rs index 71f3d6cb..53ff5f39 100644 --- a/ahnlich/types/src/ai/mod.rs +++ b/ahnlich/types/src/ai/mod.rs @@ -1,7 +1,7 @@ mod preprocess; mod query; mod server; -pub use preprocess::{ImageAction, PreprocessAction, StringAction}; +pub use preprocess::PreprocessAction; pub use query::{AIQuery, AIServerQuery}; use serde::{Deserialize, Serialize}; pub use server::{AIServerResponse, AIServerResult, AIStoreInfo}; diff --git a/ahnlich/types/src/ai/preprocess.rs b/ahnlich/types/src/ai/preprocess.rs index 248e8c3a..31ff8a78 100644 --- a/ahnlich/types/src/ai/preprocess.rs +++ b/ahnlich/types/src/ai/preprocess.rs @@ -1,36 +1,17 @@ use serde::{Deserialize, Serialize}; use std::fmt; -/// The String input has to be tokenized before saving into the model. -/// The action to be performed if the string input is too larger than the maximum tokens a -/// model can take. -#[derive(Copy, Debug, Clone, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub enum StringAction { - TruncateIfTokensExceed, - ErrorIfTokensExceed, - ModelPreprocessing -} - -/// The action to be performed if the image dimensions is larger than the maximum size a -/// model can take. 
-#[derive(Copy, Debug, Clone, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub enum ImageAction { - ResizeImage, - ErrorIfDimensionsMismatch, - ModelPreprocessing -} - #[derive(Copy, Debug, Clone, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord, Hash)] pub enum PreprocessAction { - RawString(StringAction), - Image(ImageAction), + NoPreprocessing, + ModelPreprocessing } impl fmt::Display for PreprocessAction { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - Self::RawString(_) => write!(f, "PreprocessString"), - Self::Image(_) => write!(f, "PreprocessImage"), + Self::NoPreprocessing => write!(f, "NoPreprocessing"), + Self::ModelPreprocessing => write!(f, "ModelPreprocessing"), } } } From b405579ba520765ca2f5708c934819d32cf5d8b5 Mon Sep 17 00:00:00 2001 From: HabeebShopeju Date: Mon, 25 Nov 2024 01:24:10 +0000 Subject: [PATCH 05/15] Postprocessor working on images --- ahnlich/ai/src/engine/ai/models.rs | 42 ++++---- ahnlich/ai/src/engine/ai/providers/mod.rs | 3 +- ahnlich/ai/src/engine/ai/providers/ort.rs | 90 ++++++++++------- .../processors/imagearray_to_ndarray.rs | 24 +++-- .../src/engine/ai/providers/processors/mod.rs | 16 ++- .../processors/onnx_output_transform.rs | 56 +++++++++++ .../ai/providers/processors/postprocessor.rs | 57 +++++++++-- ahnlich/ai/src/error.rs | 8 +- ahnlich/ai/src/manager/mod.rs | 98 +++++++++---------- ahnlich/typegen/src/tracers/query/ai.rs | 11 +-- 10 files changed, 271 insertions(+), 134 deletions(-) create mode 100644 ahnlich/ai/src/engine/ai/providers/processors/onnx_output_transform.rs diff --git a/ahnlich/ai/src/engine/ai/models.rs b/ahnlich/ai/src/engine/ai/models.rs index 2590c37a..67e31807 100644 --- a/ahnlich/ai/src/engine/ai/models.rs +++ b/ahnlich/ai/src/engine/ai/models.rs @@ -8,7 +8,7 @@ use ahnlich_types::{ keyval::{StoreInput, StoreKey}, }; use image::{DynamicImage, GenericImageView, ImageFormat, ImageReader}; -use ndarray::ArrayView; +use ndarray::{ArrayView, Ix4}; use ndarray::{Array, Ix3}; use nonzero_ext::nonzero; use serde::{Deserialize, Deserializer, Serialize, Serializer}; @@ -246,14 +246,15 @@ impl fmt::Display for InputAction { #[derive(Debug)] pub enum ModelInput { Texts(Vec), - Images(Vec), + Images(Array), } #[derive(Debug, Clone)] pub struct ImageArray { array: Array, image: DynamicImage, - image_format: ImageFormat + image_format: ImageFormat, + onnx_transformed: bool } impl ImageArray { @@ -288,13 +289,17 @@ impl ImageArray { .map_err(|_| AIProxyError::ImageBytesDecodeError)? .mapv(f32::from); - Ok(ImageArray { array, image, image_format: image_format.to_owned() }) + Ok(ImageArray { array, image, image_format: image_format.to_owned(), onnx_transformed: false }) } // Swapping axes from [rows, columns, channels] to [channels, rows, columns] for ONNX pub fn onnx_transform(&mut self) { + if self.onnx_transformed { + return; + } self.array.swap_axes(1, 2); self.array.swap_axes(0, 1); + self.onnx_transformed = true; } pub fn view(&self) -> ArrayView { @@ -324,7 +329,7 @@ impl ImageArray { let array = Array::from_shape_vec(shape, flattened_pixels) .map_err(|_| AIProxyError::ImageResizeError)? 
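// Editor's note (editor-added): the flattened pixel buffer holds u8 samples,
// so the `.mapv` continuation below lifts each one to f32 before the array
// reaches ONNX. The same conversion in isolation (editor's sketch):
//
//     let px: Vec<u8> = vec![0u8, 127, 255];
//     let as_f32: Vec<f32> = px.into_iter().map(f32::from).collect(); // [0.0, 127.0, 255.0]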
.mapv(f32::from); - Ok(ImageArray { array, image: resized_img, image_format: self.image_format }) + Ok(ImageArray { array, image: resized_img, image_format: self.image_format, onnx_transformed: false }) } pub fn crop(&self, x: u32, y: u32, width: u32, height: u32) -> Result { @@ -336,15 +341,21 @@ impl ImageArray { let array = Array::from_shape_vec(shape, flattened_pixels) .map_err(|_| AIProxyError::ImageCropError)? .mapv(f32::from); - Ok(ImageArray { array, image: cropped_img, image_format: self.image_format }) + Ok(ImageArray { array, image: cropped_img, image_format: self.image_format, onnx_transformed: false }) } pub fn image_dim(&self) -> (NonZeroUsize, NonZeroUsize) { let shape = self.array.shape(); - ( - NonZeroUsize::new(shape[1]).expect("Array columns should be non-zero"), - NonZeroUsize::new(shape[0]).expect("Array rows should be non-zero"), - ) // (width, height) + match self.onnx_transformed { + true => ( + NonZeroUsize::new(shape[2]).expect("Array columns should be non-zero"), + NonZeroUsize::new(shape[1]).expect("Array channels should be non-zero"), + ), // (width, channels) + false => ( + NonZeroUsize::new(shape[1]).expect("Array columns should be non-zero"), + NonZeroUsize::new(shape[0]).expect("Array rows should be non-zero"), + ) // (width, height) + } } } @@ -367,17 +378,6 @@ impl<'de> Deserialize<'de> for ImageArray { } } -// impl TryFrom for ModelInput { -// type Error = AIProxyError; -// -// fn try_from(value: StoreInput) -> Result { -// match value { -// StoreInput::RawString(s) => Ok(ModelInput::Text(s)), -// StoreInput::Image(bytes) => Ok(ModelInput::Image(ImageArray::try_new(bytes)?)), -// } -// } -// } - impl From<&ModelInput> for AIStoreInputType { fn from(value: &ModelInput) -> AIStoreInputType { match value { diff --git a/ahnlich/ai/src/engine/ai/providers/mod.rs b/ahnlich/ai/src/engine/ai/providers/mod.rs index 8aa78bd2..661900dc 100644 --- a/ahnlich/ai/src/engine/ai/providers/mod.rs +++ b/ahnlich/ai/src/engine/ai/providers/mod.rs @@ -1,6 +1,7 @@ pub(crate) mod ort; mod ort_helper; -mod processors; +pub mod processors; + use crate::cli::server::SupportedModels; use crate::engine::ai::models::{InputAction, ModelInput}; diff --git a/ahnlich/ai/src/engine/ai/providers/ort.rs b/ahnlich/ai/src/engine/ai/providers/ort.rs index f356d71a..29aad3b3 100644 --- a/ahnlich/ai/src/engine/ai/providers/ort.rs +++ b/ahnlich/ai/src/engine/ai/providers/ort.rs @@ -7,11 +7,11 @@ use fallible_collections::FallibleVec; use hf_hub::{api::sync::ApiBuilder, Cache}; use itertools::Itertools; use rayon::iter::Either; -use ort::{Session, Value}; +use ort::{Session, SessionOutputs, Value}; use rayon::prelude::*; use ahnlich_types::keyval::StoreKey; -use ndarray::{Array, Array1, Axis, Ix2, Ix3, Ix4, IxDyn, IxDynImpl}; +use ndarray::{Array, Array1, ArrayView, Axis, Ix2, Ix3, Ix4, IxDyn, IxDynImpl}; use std::convert::TryFrom; use std::default::Default; use std::fmt; @@ -22,7 +22,7 @@ use crate::engine::ai::providers::processors::preprocessor::{ImagePreprocessorFi use crate::engine::ai::providers::ort_helper::normalize; use ndarray::s; use tokenizers::Tokenizer; -use crate::engine::ai::providers::processors::postprocessor::{ORTPostprocessor, ORTTextPostprocessor}; +use crate::engine::ai::providers::processors::postprocessor::{ORTImagePostprocessor, ORTPostprocessor, ORTTextPostprocessor}; #[derive(Default)] pub struct ORTProvider { @@ -215,44 +215,46 @@ impl ORTProvider { } } - pub fn batch_inference_image(&self, inputs: Vec) -> Result, AIProxyError> { + pub fn 
postprocess_image_inference(&self, embeddings: SessionOutputs) -> Result, AIProxyError> { + match &self.postprocessor { + Some(ORTPostprocessor::Image(postprocessor)) => { + let output_data = postprocessor.process(embeddings) + .map_err( + |e| AIProxyError::ModelProviderPostprocessingError( + format!("Postprocessing failed for {:?} with error: {}", + self.supported_models.unwrap().to_string(), e) + ))?; + Ok(output_data) + } + _ => Err(AIProxyError::ModelPostprocessingError { + model_name: self.supported_models.unwrap().to_string(), + message: "Postprocessor not initialized".to_string(), + }) + } + } + + pub fn batch_inference_image(&self, inputs: Array) -> Result, AIProxyError> { let model = match &self.model { Some(ORTModel::Image(model)) => model, _ => return Err(AIProxyError::AIModelNotSupported { model_name: self.supported_models.unwrap().to_string() }), }; - let pixel_values_array = self.preprocess_images(inputs)?; match &model.session { Some(session) => { let session_inputs = ort::inputs![ model.input_params.first().expect("Hardcoded in parameters") - .as_str() => pixel_values_array.view(), + .as_str() => inputs.view(), ].map_err(|e| AIProxyError::ModelProviderPreprocessingError(e.to_string()))?; let outputs = session.run(session_inputs) .map_err(|e| AIProxyError::ModelProviderRunInferenceError(e.to_string()))?; - let last_hidden_state_key = match outputs.len() { - 1 => outputs.keys().next().unwrap(), - _ => model.output_param.as_str(), - }; - - let output_data = outputs[last_hidden_state_key] - .try_extract_tensor::() - .map_err(|e| AIProxyError::ModelProviderPostprocessingError(e.to_string()))?; - let store_keys = output_data - .axis_iter(Axis(0)) - .into_par_iter() - .map(|row| { - let embeddings = normalize(row.as_slice().unwrap()); - StoreKey(>::from(embeddings)) - }) - .collect(); - Ok(store_keys) + let embeddings = self.postprocess_image_inference(outputs)?; + Ok(embeddings) } None => Err(AIProxyError::AIModelNotInitialized) } } - pub fn batch_inference_text(&self, encodings: Vec) -> Result, AIProxyError> { + pub fn batch_inference_text(&self, encodings: Vec) -> Result, AIProxyError> { let model = match &self.model { Some(ORTModel::Text(model)) => model, _ => return Err(AIProxyError::AIModelNotSupported { model_name: self.supported_models.unwrap().to_string() }), @@ -338,13 +340,7 @@ impl ORTProvider { let session_output = session_output .to_owned(); let embeddings = self.postprocess_text_embeddings(session_output, attention_mask_array)?; - println!("Embeddings: {:?}", embeddings); - let store_keys = embeddings - .axis_iter(Axis(0)) - .into_par_iter() - .map(|embedding| StoreKey(>::from(embedding.to_owned()))) - .collect(); - Ok(store_keys) + Ok(embeddings.to_owned()) } None => Err(AIProxyError::AIModelNotInitialized), } @@ -407,8 +403,9 @@ impl ProviderTrait for ORTProvider { })); let mut preprocessor = ORTImagePreprocessor::default(); preprocessor.load(model_repo, preprocessor_files)?; - self.preprocessor = Some(ORTPreprocessor::Image(preprocessor) - ); + self.preprocessor = Some(ORTPreprocessor::Image(preprocessor)); + let postprocessor = ORTImagePostprocessor::load(supported_model)?; + self.postprocessor = Some(ORTPostprocessor::Image(postprocessor)); }, ORTModel::Text(ORTTextModel { weights_file, @@ -487,14 +484,35 @@ impl ProviderTrait for ORTProvider { ) -> Result, AIProxyError> { match input { - ModelInput::Images(images) => self.batch_inference_image(images), + ModelInput::Images(images) => { + let mut store_keys: Vec = FallibleVec::try_with_capacity(images.len())?; 
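// Editor's note (illustrative, editor-added): the loop below walks the batch
// axis in fixed windows of 16 so peak inference memory stays bounded. The
// chunking idiom on its own, assuming only ndarray:
//
//     use ndarray::{Array4, Axis};
//     let images = Array4::<f32>::zeros((40, 3, 224, 224));
//     for chunk in images.axis_chunks_iter(Axis(0), 16) {
//         // windows of 16, 16, then 8 rows along the batch axis
//         assert!(chunk.shape()[0] <= 16);
//     }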
+ + for batch_image in images.axis_chunks_iter(Axis(0), 16).into_iter() { + let embeddings = self.batch_inference_image(batch_image.to_owned())?; + let new_store_keys: Vec = embeddings + .axis_iter(Axis(0)) + .into_par_iter() + .map(|embedding| StoreKey(>::from(embedding.to_owned())) + ) + .collect(); + store_keys.extend(new_store_keys); + } + Ok(store_keys) + }, ModelInput::Texts(encodings) => { - let mut store_keys: Vec<_> = FallibleVec::try_with_capacity( + let mut store_keys: Vec = FallibleVec::try_with_capacity( encodings.len() )?; for batch_encoding in encodings.into_iter().chunks(16).into_iter() { - store_keys.extend(self.batch_inference_text(batch_encoding.collect())?); + let embeddings = self.batch_inference_text(batch_encoding.collect())?; + let new_store_keys: Vec = embeddings + .axis_iter(Axis(0)) + .into_par_iter() + .map(|embedding| StoreKey(>::from(embedding.to_owned())) + ) + .collect(); + store_keys.extend(new_store_keys); } Ok(store_keys) }, diff --git a/ahnlich/ai/src/engine/ai/providers/processors/imagearray_to_ndarray.rs b/ahnlich/ai/src/engine/ai/providers/processors/imagearray_to_ndarray.rs index 4ee46795..f562d9fc 100644 --- a/ahnlich/ai/src/engine/ai/providers/processors/imagearray_to_ndarray.rs +++ b/ahnlich/ai/src/engine/ai/providers/processors/imagearray_to_ndarray.rs @@ -1,4 +1,6 @@ +use image::image_dimensions; use ndarray::{ArrayView, Ix3}; +use std::sync::Mutex; use rayon::iter::{IntoParallelRefMutIterator, ParallelIterator}; use crate::engine::ai::providers::processors::{Preprocessor, PreprocessorData}; use crate::error::AIProxyError; @@ -9,16 +11,22 @@ impl Preprocessor for ImageArrayToNdArray { fn process(&self, data: PreprocessorData) -> Result { match data { PreprocessorData::ImageArray(mut arrays) => { - let array_views: Vec> = arrays - .par_iter_mut() - .map(|image_arr| { - image_arr.onnx_transform(); - image_arr.view() - }) - .collect(); + let mut array_shapes = Mutex::new(vec![]); + let mut array_views = Mutex::new(vec![]); + arrays.par_iter_mut().for_each(|image_arr| { + image_arr.onnx_transform(); + array_shapes.lock().unwrap().push(image_arr.image_dim()); + array_views.lock().unwrap().push(image_arr.view()); + }); + + let array_shapes = array_shapes.into_inner().unwrap(); + let array_views = array_views.into_inner().unwrap(); let pixel_values_array = ndarray::stack(ndarray::Axis(0), &array_views) - .map_err(|e| AIProxyError::EmbeddingShapeError(e.to_string()))?; + .map_err(|e| AIProxyError::ImageArrayToNdArrayError { + message: format!("Images must have same dimensions, instead found: {:?}. 
\ + NB: Dimensions listed are not in same order as images provided.", array_shapes), + })?; Ok(PreprocessorData::NdArray3C(pixel_values_array)) } _ => Err(AIProxyError::ImageArrayToNdArrayError { diff --git a/ahnlich/ai/src/engine/ai/providers/processors/mod.rs b/ahnlich/ai/src/engine/ai/providers/processors/mod.rs index cb078275..e333fad1 100644 --- a/ahnlich/ai/src/engine/ai/providers/processors/mod.rs +++ b/ahnlich/ai/src/engine/ai/providers/processors/mod.rs @@ -1,6 +1,7 @@ use crate::engine::ai::models::ImageArray; use crate::error::AIProxyError; use ndarray::{Array, Ix2, Ix3, Ix4}; +use ort::SessionOutputs; use tokenizers::Encoding; pub mod normalize; @@ -12,6 +13,7 @@ pub mod preprocessor; pub mod tokenize; pub mod postprocessor; pub mod pooling; +mod onnx_output_transform; pub const CONV_NEXT_FEATURE_EXTRACTOR_CENTER_CROP_THRESHOLD: u32 = 384; @@ -30,7 +32,19 @@ pub enum PreprocessorData { EncodedText(Vec), } -pub enum PostprocessorData { +impl PreprocessorData { + pub fn into_ndarray3c(self) -> Result, AIProxyError> { + match self { + PreprocessorData::NdArray3C(array) => Ok(array), + _ => Err(AIProxyError::ModelProviderPreprocessingError( + "`into_ndarray3c` only works for PreprocessorData::NdArray3C".to_string() + )), + } + } +} + +pub enum PostprocessorData<'r, 's> { + OnnxOutput(SessionOutputs<'r, 's>), NdArray2(Array), NdArray3(Array) } diff --git a/ahnlich/ai/src/engine/ai/providers/processors/onnx_output_transform.rs b/ahnlich/ai/src/engine/ai/providers/processors/onnx_output_transform.rs new file mode 100644 index 00000000..6419bb70 --- /dev/null +++ b/ahnlich/ai/src/engine/ai/providers/processors/onnx_output_transform.rs @@ -0,0 +1,56 @@ +use ndarray::{Ix2, Ix3}; +use crate::engine::ai::providers::processors::{Postprocessor, PostprocessorData}; +use crate::error::AIProxyError; + + +pub struct OnnxOutputTransform { + output_key: String +} + +impl OnnxOutputTransform { + pub fn new(output_key: String) -> Self { + Self { output_key } + } +} + +impl Postprocessor for OnnxOutputTransform { + fn process(&self, data: PostprocessorData) -> Result { + match data { + PostprocessorData::OnnxOutput(onnx_output) => { + let output = onnx_output.get(self.output_key.as_str()) + .ok_or_else(|| AIProxyError::OnnxOutputTransformError { + message: format!("Output key '{}' not found in the OnnxOutput.", self.output_key), + })?; + let output = output.try_extract_tensor::().map_err( + |_| AIProxyError::OnnxOutputTransformError { + message: "Failed to extract tensor from OnnxOutput.".to_string(), + } + )?; + match output.ndim() { + 2 => { + let output = output.into_dimensionality::().map_err( + |_| AIProxyError::OnnxOutputTransformError { + message: "Failed to convert Dyn tensor to 2D array.".to_string(), + } + )?; + Ok(PostprocessorData::NdArray2(output.to_owned())) + }, + 3 => { + let output = output.into_dimensionality::().map_err( + |_| AIProxyError::OnnxOutputTransformError { + message: "Failed to convert Dyn tensor to 3D array.".to_string(), + } + )?; + Ok(PostprocessorData::NdArray3(output.to_owned())) + }, + _ => Err(AIProxyError::OnnxOutputTransformError { + message: "Only 2D and 3D tensors are supported.".to_string(), + }), + } + } + _ => Err(AIProxyError::OnnxOutputTransformError { + message: "Only OnnxOutput is supported".to_string(), + }), + } + } +} \ No newline at end of file diff --git a/ahnlich/ai/src/engine/ai/providers/processors/postprocessor.rs b/ahnlich/ai/src/engine/ai/providers/processors/postprocessor.rs index 2bf2d504..500fd5e0 100644 --- 
a/ahnlich/ai/src/engine/ai/providers/processors/postprocessor.rs +++ b/ahnlich/ai/src/engine/ai/providers/processors/postprocessor.rs @@ -1,9 +1,11 @@ use std::sync::{Arc, Mutex}; use ndarray::{Array, Ix2, Ix3}; +use ort::SessionOutputs; use crate::cli::server::SupportedModels; use crate::engine::ai::providers::processors::normalize::VectorNormalize; use crate::engine::ai::providers::processors::pooling::{Pooling, RegularPooling, MeanPooling}; use crate::engine::ai::providers::processors::{PostprocessorData, Postprocessor}; +use crate::engine::ai::providers::processors::onnx_output_transform::OnnxOutputTransform; use crate::error::AIProxyError; pub enum ORTPostprocessor { @@ -20,7 +22,7 @@ pub struct ORTTextPostprocessor { impl ORTTextPostprocessor { pub fn load(supported_model: SupportedModels) -> Result { - let artifacts = match supported_model { + let ops = match supported_model { SupportedModels::AllMiniLML6V2 | SupportedModels::AllMiniLML12V2 => Ok(( Pooling::Mean(MeanPooling::new()), @@ -42,8 +44,8 @@ impl ORTTextPostprocessor { }?; Ok(Self { model: supported_model, - pooling: Arc::new(Mutex::new(artifacts.0)), - normalize: artifacts.1 + pooling: Arc::new(Mutex::new(ops.0)), + normalize: ops.1 }) } @@ -69,13 +71,56 @@ impl ORTTextPostprocessor { PostprocessorData::NdArray2(array) => Ok(array), _ => Err(AIProxyError::ModelPostprocessingError { model_name: self.model.to_string(), - message: "Expected NdArray2, got NdArray3".to_string() + message: "Only returns NdArray2".to_string() }) } } } -struct ORTImagePostprocessor { +pub struct ORTImagePostprocessor { model: SupportedModels, - normalize: VectorNormalize + onnx_output_transform: OnnxOutputTransform, + normalize: Option +} + +impl ORTImagePostprocessor { + pub fn load(supported_model: SupportedModels) -> Result { + let output_transform = match supported_model { + SupportedModels::Resnet50 | + SupportedModels::ClipVitB32Image => + OnnxOutputTransform::new("image_embeds".to_string()), + _ => Err(AIProxyError::ModelPostprocessingError { + model_name: supported_model.to_string(), + message: "Unsupported model for ORTImagePostprocessor".to_string() + })? + }; + let normalize = match supported_model { + SupportedModels::Resnet50 => Ok(Some(VectorNormalize)), + SupportedModels::ClipVitB32Image => Ok(None), + _ => Err(AIProxyError::ModelPostprocessingError { + model_name: supported_model.to_string(), + message: "Unsupported model for ORTImagePostprocessor".to_string() + }) + }?; + Ok(Self { + model: supported_model, + normalize, + onnx_output_transform: output_transform + }) + } + + pub fn process(&self, session_outputs: SessionOutputs) -> Result, AIProxyError> { + let embeddings = self.onnx_output_transform.process(PostprocessorData::OnnxOutput(session_outputs))?; + let result = match &self.normalize { + Some(normalize) => normalize.process(embeddings), + None => Ok(embeddings) + }?; + match result { + PostprocessorData::NdArray2(array) => Ok(array), + _ => Err(AIProxyError::ModelPostprocessingError { + model_name: self.model.to_string(), + message: "Only returns NdArray2".to_string() + }) + } + } } \ No newline at end of file diff --git a/ahnlich/ai/src/error.rs b/ahnlich/ai/src/error.rs index 3c78234b..ef3ce57e 100644 --- a/ahnlich/ai/src/error.rs +++ b/ahnlich/ai/src/error.rs @@ -30,8 +30,7 @@ pub enum AIProxyError { index_model_type: AIStoreInputType, storeinput_type: AIStoreInputType, }, - #[error("Shape error {0}")] - EmbeddingShapeError(String), + #[error("Max Token Exceeded. 
Model Expects [{max_token_size}], input type was [{input_token_size}].")] TokenExceededError { max_token_size: usize, @@ -109,6 +108,11 @@ pub enum AIProxyError { message: String }, + #[error("Onnx output transform error: [{message}]")] + OnnxOutputTransformError { + message: String + }, + #[error("Rescale error: [{message}]")] RescaleError { message: String diff --git a/ahnlich/ai/src/manager/mod.rs b/ahnlich/ai/src/manager/mod.rs index b445ec12..e99e31c6 100644 --- a/ahnlich/ai/src/manager/mod.rs +++ b/ahnlich/ai/src/manager/mod.rs @@ -8,12 +8,16 @@ use crate::engine::ai::models::{ImageArray, InputAction}; /// channel use crate::engine::ai::models::{Model, ModelInput}; use crate::engine::ai::providers::ModelProviders; +use crate::engine::ai::providers::processors::{Preprocessor, PreprocessorData}; +use crate::engine::ai::providers::processors::imagearray_to_ndarray::ImageArrayToNdArray; + use crate::engine::ai::providers::ort::ORTProvider; use crate::error::AIProxyError; use ahnlich_types::ai::{AIModel, AIStoreInputType, PreprocessAction}; use ahnlich_types::keyval::{StoreInput, StoreKey}; use fallible_collections::FallibleVec; use moka::future::Cache; +use ndarray::{Array, Ix4}; use rayon::prelude::*; use tokenizers::Encoding; use task_manager::Task; @@ -86,20 +90,20 @@ impl ModelThread { })?; match sample { StoreInput::RawString(_) => { - let inputs_inner: Vec = inputs.into_par_iter().filter_map(|input| match input { + let inputs: Vec = inputs.into_par_iter().filter_map(|input| match input { StoreInput::RawString(string) => Some(string), _ => None, }).collect(); - let output = self.preprocess_raw_string(inputs_inner, process_action)?; + let output = self.preprocess_raw_string(inputs, process_action)?; Ok(ModelInput::Texts(output)) } StoreInput::Image(_) => { - // let inputs_inner: Vec> = inputs.into_par_iter().filter_map(|input| match input { - // StoreInput::Image(image_bytes) => Some(image_bytes), - // _ => None, - // }).collect(); - // let output = self.preprocess_image(inputs_inner, process_action)?; - Ok(ModelInput::Images(vec![])) + let inputs = inputs.into_par_iter().filter_map(|input| match input { + StoreInput::Image(image_bytes) => Some(ImageArray::try_new(image_bytes).ok()?), + _ => None, + }).collect(); + let output = self.preprocess_image(inputs, process_action)?; + Ok(ModelInput::Images(output)) } } } @@ -109,17 +113,10 @@ impl ModelThread { inputs: Vec, process_action: PreprocessAction, ) -> Result, AIProxyError> { - if self.model.input_type() != AIStoreInputType::RawString { - return Err(AIProxyError::ModelPreprocessingError { - model_name: self.model.model_name(), - message: "RawString preprocessing is not supported.".to_string(), - }); - } - let max_token_size = usize::from(self.model.max_input_token().ok_or_else(|| { - AIProxyError::AIModelInvalidOperation { + AIProxyError::ModelPreprocessingError { model_name: self.model.model_name(), - operation: "[max_input_token] function".to_string() + message: "RawString preprocessing is not supported.".to_string(), } })?); @@ -146,44 +143,45 @@ impl ModelThread { } } - #[tracing::instrument(skip(self, input))] + #[tracing::instrument(skip(self, inputs))] fn preprocess_image( &self, - input: ImageArray, + inputs: Vec, process_action: PreprocessAction - // input: Vec>, - ) -> Result { - Err(AIProxyError::ModelPreprocessingError { + ) -> Result, AIProxyError> { + // process image, return error if max dimensions exceeded + let (expected_width, expected_height) = self.model.expected_image_dimensions() + 
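// Editor's note (editor-added): `expected_image_dimensions()` returns None for
// text-only models, and the `.ok_or(..)` continuation below converts that
// absence into a typed preprocessing error instead of a panic. The idiom in
// miniature (editor's sketch):
//
//     let dims: Option<(usize, usize)> = None;
//     let checked = dims.ok_or("model has no image dimensions"); // Err(..)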
.ok_or(AIProxyError::ModelPreprocessingError { model_name: self.model.model_name(), message: "Image preprocessing is not supported.".to_string(), - }) - // process image, return error if max dimensions exceeded - // let dimensions = input.image_dim(); - // - // let Some((expected_width, expected_height)) = - // self.model.expected_image_dimensions() - // .ok_or( - // Err(AIProxyError::PreprocessingMismatchError { - // input_type: AIStoreInputType::Image, - // preprocess_action: process_action, - // }))?; - // - // match process_action { - // PreprocessAction::NoPreprocessing => { - // Ok(input) - // } - // PreprocessAction::ModelPreprocessing => { - // let (width, height) = dimensions; - // if width != expected_width || height != expected_height { - // Err(AIProxyError::ImageDimensionsMismatchError { - // image_dimensions: (width.into(), height.into()), - // expected_dimensions: (expected_width.into(), expected_height.into()), - // }) - // } else { - // Ok(input) - // } - // } - // } + })?; + let expected_width = usize::from(expected_width); + let expected_height = usize::from(expected_height); + + match &self.model.provider { + ModelProviders::ORT(provider) => { + let outputs = match process_action { + PreprocessAction::ModelPreprocessing => { + provider.preprocess_images(inputs)? + } + PreprocessAction::NoPreprocessing => { + ImageArrayToNdArray.process(PreprocessorData::ImageArray(inputs))? + .into_ndarray3c()? + } + }; + let outputs_shape = outputs.shape(); + let width = *outputs_shape.get(2).expect("Must exist"); + let height = *outputs_shape.get(3).expect("Must exist"); + if width != expected_width || height != expected_height { + return Err(AIProxyError::ImageDimensionsMismatchError { + image_dimensions: (width, height), + expected_dimensions: (expected_width.into(), expected_height.into()), + }); + } else { + return Ok(outputs); + } + } + } } } diff --git a/ahnlich/typegen/src/tracers/query/ai.rs b/ahnlich/typegen/src/tracers/query/ai.rs index 07a58925..993e4cd2 100644 --- a/ahnlich/typegen/src/tracers/query/ai.rs +++ b/ahnlich/typegen/src/tracers/query/ai.rs @@ -1,4 +1,4 @@ -use ahnlich_types::ai::{AIModel, AIStoreInputType, ImageAction, PreprocessAction, StringAction}; +use ahnlich_types::ai::{AIModel, AIStoreInputType, PreprocessAction}; use ahnlich_types::keyval::StoreInput; use ahnlich_types::predicate::Predicate; use ahnlich_types::predicate::PredicateCondition; @@ -81,7 +81,7 @@ pub fn trace_ai_query_enum() -> Registry { let set = AIQuery::Set { store: sample_store_name.clone(), - preprocess_action: PreprocessAction::Image(ImageAction::ErrorIfDimensionsMismatch), + preprocess_action: PreprocessAction::NoPreprocessing, inputs: vec![(test_search_input_bin.clone(), store_value)], }; @@ -170,13 +170,6 @@ pub fn trace_ai_query_enum() -> Registry { let _ = tracer .trace_type::(&samples) .expect("Error tracing AIModel"); - - let _ = tracer - .trace_type::(&samples) - .expect("Error tracing String action"); - let _ = tracer - .trace_type::(&samples) - .expect("Error tracing image action"); let _ = tracer .trace_type::(&samples) .inspect_err(|err| println!("Failed to parse type {}", err.explanation())) From 92cddd2b45c81748e5810180ec47a5dbdfd164ed Mon Sep 17 00:00:00 2001 From: HabeebShopeju Date: Tue, 26 Nov 2024 10:17:09 +0000 Subject: [PATCH 06/15] Preprocessors and Postprocessors now work --- ahnlich/ai/src/engine/ai/models.rs | 49 ++- ahnlich/ai/src/engine/ai/providers/mod.rs | 2 - ahnlich/ai/src/engine/ai/providers/ort.rs | 315 ++++++++---------- 
.../ai/src/engine/ai/providers/ort_helper.rs | 54 +-- .../ai/providers/processors/center_crop.rs | 151 +++++---- .../processors/imagearray_to_ndarray.rs | 39 ++- .../src/engine/ai/providers/processors/mod.rs | 18 +- .../ai/providers/processors/normalize.rs | 71 ++-- .../processors/onnx_output_transform.rs | 41 +-- .../engine/ai/providers/processors/pooling.rs | 37 +- .../ai/providers/processors/postprocessor.rs | 120 ++++--- .../ai/providers/processors/preprocessor.rs | 266 +++++++-------- .../engine/ai/providers/processors/rescale.rs | 28 +- .../engine/ai/providers/processors/resize.rs | 84 ++--- .../ai/providers/processors/tokenize.rs | 91 +++-- ahnlich/ai/src/error.rs | 56 +--- ahnlich/ai/src/manager/mod.rs | 86 ++--- ahnlich/ai/src/server/task.rs | 14 +- ahnlich/ai/src/tests/aiproxy_test.rs | 2 +- ahnlich/types/src/ai/preprocess.rs | 2 +- 20 files changed, 776 insertions(+), 750 deletions(-) diff --git a/ahnlich/ai/src/engine/ai/models.rs b/ahnlich/ai/src/engine/ai/models.rs index 67e31807..6075fba5 100644 --- a/ahnlich/ai/src/engine/ai/models.rs +++ b/ahnlich/ai/src/engine/ai/models.rs @@ -5,20 +5,20 @@ use crate::engine::ai::providers::ProviderTrait; use crate::error::AIProxyError; use ahnlich_types::{ ai::{AIModel, AIStoreInputType}, - keyval::{StoreInput, StoreKey}, + keyval::StoreKey, }; use image::{DynamicImage, GenericImageView, ImageFormat, ImageReader}; -use ndarray::{ArrayView, Ix4}; use ndarray::{Array, Ix3}; +use ndarray::{ArrayView, Ix4}; use nonzero_ext::nonzero; +use serde::de::Error as DeError; +use serde::ser::Error as SerError; use serde::{Deserialize, Deserializer, Serialize, Serializer}; use std::fmt; use std::io::Cursor; use std::num::NonZeroUsize; use std::path::Path; use strum::Display; -use serde::ser::Error as SerError; -use serde::de::Error as DeError; use tokenizers::Encoding; #[derive(Display)] @@ -254,7 +254,7 @@ pub struct ImageArray { array: Array, image: DynamicImage, image_format: ImageFormat, - onnx_transformed: bool + onnx_transformed: bool, } impl ImageArray { @@ -289,7 +289,12 @@ impl ImageArray { .map_err(|_| AIProxyError::ImageBytesDecodeError)? .mapv(f32::from); - Ok(ImageArray { array, image, image_format: image_format.to_owned(), onnx_transformed: false }) + Ok(ImageArray { + array, + image, + image_format: image_format.to_owned(), + onnx_transformed: false, + }) } // Swapping axes from [rows, columns, channels] to [channels, rows, columns] for ONNX @@ -308,20 +313,22 @@ impl ImageArray { pub fn get_bytes(&self) -> Result, AIProxyError> { let mut buffer = Cursor::new(Vec::new()); - let _ = &self.image + let _ = &self + .image .write_to(&mut buffer, self.image_format) .map_err(|_| AIProxyError::ImageBytesEncodeError)?; let bytes = buffer.into_inner(); Ok(bytes) } - pub fn resize(&self, width: u32, height: u32, filter: Option) -> Result { + pub fn resize( + &self, + width: u32, + height: u32, + filter: Option, + ) -> Result { let filter_type = filter.unwrap_or(image::imageops::FilterType::CatmullRom); - let resized_img = self.image.resize_exact( - width, - height, - filter_type, - ); + let resized_img = self.image.resize_exact(width, height, filter_type); let channels = resized_img.color().channel_count(); let shape = (height as usize, width as usize, channels as usize); @@ -329,7 +336,12 @@ impl ImageArray { let array = Array::from_shape_vec(shape, flattened_pixels) .map_err(|_| AIProxyError::ImageResizeError)? 
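// Editor's note (illustrative, editor-added): `resize` falls back to
// CatmullRom when no filter is given; a caller wanting a different
// speed/quality trade-off would pass one explicitly, e.g. (assumed call site):
//
//     let resized = image_array.resize(224, 224, Some(image::imageops::FilterType::Triangle))?;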
.mapv(f32::from); - Ok(ImageArray { array, image: resized_img, image_format: self.image_format, onnx_transformed: false }) + Ok(ImageArray { + array, + image: resized_img, + image_format: self.image_format, + onnx_transformed: false, + }) } pub fn crop(&self, x: u32, y: u32, width: u32, height: u32) -> Result { @@ -341,7 +353,12 @@ impl ImageArray { let array = Array::from_shape_vec(shape, flattened_pixels) .map_err(|_| AIProxyError::ImageCropError)? .mapv(f32::from); - Ok(ImageArray { array, image: cropped_img, image_format: self.image_format, onnx_transformed: false }) + Ok(ImageArray { + array, + image: cropped_img, + image_format: self.image_format, + onnx_transformed: false, + }) } pub fn image_dim(&self) -> (NonZeroUsize, NonZeroUsize) { @@ -354,7 +371,7 @@ impl ImageArray { false => ( NonZeroUsize::new(shape[1]).expect("Array columns should be non-zero"), NonZeroUsize::new(shape[0]).expect("Array rows should be non-zero"), - ) // (width, height) + ), // (width, height) } } } diff --git a/ahnlich/ai/src/engine/ai/providers/mod.rs b/ahnlich/ai/src/engine/ai/providers/mod.rs index 661900dc..86adc61e 100644 --- a/ahnlich/ai/src/engine/ai/providers/mod.rs +++ b/ahnlich/ai/src/engine/ai/providers/mod.rs @@ -2,7 +2,6 @@ pub(crate) mod ort; mod ort_helper; pub mod processors; - use crate::cli::server::SupportedModels; use crate::engine::ai::models::{InputAction, ModelInput}; use crate::engine::ai::providers::ort::ORTProvider; @@ -11,7 +10,6 @@ use ahnlich_types::keyval::StoreKey; use std::path::Path; use strum::EnumIter; - #[derive(Debug, EnumIter)] pub enum ModelProviders { ORT(ORTProvider), diff --git a/ahnlich/ai/src/engine/ai/providers/ort.rs b/ahnlich/ai/src/engine/ai/providers/ort.rs index 29aad3b3..c2acdeb9 100644 --- a/ahnlich/ai/src/engine/ai/providers/ort.rs +++ b/ahnlich/ai/src/engine/ai/providers/ort.rs @@ -1,28 +1,27 @@ use crate::cli::server::SupportedModels; -use crate::engine::ai::models::{ImageArray, InputAction, Model, ModelInput}; -use crate::engine::ai::providers::processors::tokenize::Tokenize; +use crate::engine::ai::models::{ImageArray, InputAction, ModelInput}; use crate::engine::ai::providers::ProviderTrait; use crate::error::AIProxyError; use fallible_collections::FallibleVec; use hf_hub::{api::sync::ApiBuilder, Cache}; use itertools::Itertools; -use rayon::iter::Either; use ort::{Session, SessionOutputs, Value}; use rayon::prelude::*; +use crate::engine::ai::providers::processors::postprocessor::{ + ORTImagePostprocessor, ORTPostprocessor, ORTTextPostprocessor, +}; +use crate::engine::ai::providers::processors::preprocessor::{ + ORTImagePreprocessor, ORTPreprocessor, ORTTextPreprocessor, +}; use ahnlich_types::keyval::StoreKey; -use ndarray::{Array, Array1, ArrayView, Axis, Ix2, Ix3, Ix4, IxDyn, IxDynImpl}; +use ndarray::{Array, Array1, Axis, Ix2, Ix4}; use std::convert::TryFrom; use std::default::Default; use std::fmt; use std::path::{Path, PathBuf}; use std::thread::available_parallelism; use tokenizers::Encoding; -use crate::engine::ai::providers::processors::preprocessor::{ImagePreprocessorFiles, ORTImagePreprocessor, ORTPreprocessor, ORTTextPreprocessor, TextPreprocessorFiles}; -use crate::engine::ai::providers::ort_helper::normalize; -use ndarray::s; -use tokenizers::Tokenizer; -use crate::engine::ai::providers::processors::postprocessor::{ORTImagePostprocessor, ORTPostprocessor, ORTTextPostprocessor}; #[derive(Default)] pub struct ORTProvider { @@ -49,9 +48,6 @@ pub struct ORTImageModel { repo_name: String, weights_file: String, session: Option, - 
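// Editor's note (editor-added): the removals just below take the per-model
// tensor names out of this struct; after this commit they are resolved by
// matching on SupportedModels at the inference call site, so ORTImageModel
// keeps only the repo name, the weights file, and the lazily created Session.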
input_params: Vec, - output_param: String, - preprocessor_files: ImagePreprocessorFiles } #[derive(Default)] @@ -59,15 +55,13 @@ pub struct ORTTextModel { repo_name: String, weights_file: String, session: Option, - preprocessor_files: TextPreprocessorFiles } pub enum ORTModel { Image(ORTImageModel), - Text(ORTTextModel) + Text(ORTTextModel), } - impl TryFrom<&SupportedModels> for ORTModel { type Error = AIProxyError; @@ -76,15 +70,11 @@ impl TryFrom<&SupportedModels> for ORTModel { SupportedModels::Resnet50 => Ok(ORTModel::Image(ORTImageModel { repo_name: "Qdrant/resnet50-onnx".to_string(), weights_file: "model.onnx".to_string(), - input_params: vec!["input".to_string()], - output_param: "image_embeds".to_string(), ..Default::default() })), SupportedModels::ClipVitB32Image => Ok(ORTModel::Image(ORTImageModel { repo_name: "Qdrant/clip-ViT-B-32-vision".to_string(), weights_file: "model.onnx".to_string(), - input_params: vec!["pixel_values".to_string()], - output_param: "image_embeds".to_string(), ..Default::default() })), SupportedModels::ClipVitB32Text => Ok(ORTModel::Text(ORTTextModel { @@ -112,9 +102,6 @@ impl TryFrom<&SupportedModels> for ORTModel { weights_file: "onnx/model.onnx".to_string(), ..Default::default() })), - _ => Err(AIProxyError::AIModelNotSupported { - model_name: model.to_string(), - }), }; model_type @@ -133,131 +120,145 @@ impl ORTProvider { } } - fn get_postprocessor() -> Result<(), AIProxyError> { - Ok(()) - } - - pub fn preprocess_images(&self, data: Vec) -> Result, AIProxyError> { + pub fn preprocess_images( + &self, + data: Vec, + ) -> Result, AIProxyError> { match &self.preprocessor { Some(ORTPreprocessor::Image(preprocessor)) => { - let output_data = preprocessor.process(data) - .map_err( - |e| AIProxyError::ModelProviderPreprocessingError( - format!("Preprocessing failed for {:?} with error: {}", - self.supported_models.unwrap().to_string(), e) - ))?; + let output_data = preprocessor.process(data).map_err(|e| { + AIProxyError::ModelProviderPreprocessingError(format!( + "Preprocessing failed for {:?} with error: {}", + self.supported_models.unwrap().to_string(), + e + )) + })?; Ok(output_data) } - _ => Err(AIProxyError::AIModelNotInitialized) + _ => Err(AIProxyError::AIModelNotInitialized), } } - pub fn preprocess_texts(&self, data: Vec, truncate: bool) -> Result, AIProxyError> { + pub fn preprocess_texts( + &self, + data: Vec, + truncate: bool, + ) -> Result, AIProxyError> { match &self.preprocessor { Some(ORTPreprocessor::Text(preprocessor)) => { - let output_data = preprocessor.process(data, truncate) - .map_err( - |e| AIProxyError::ModelProviderPreprocessingError( - format!("Preprocessing failed for {:?} with error: {}", - self.supported_models.unwrap().to_string(), e) - ))?; + let output_data = preprocessor.process(data, truncate).map_err(|e| { + AIProxyError::ModelProviderPreprocessingError(format!( + "Preprocessing failed for {:?} with error: {}", + self.supported_models.unwrap().to_string(), + e + )) + })?; Ok(output_data) } _ => Err(AIProxyError::ModelPreprocessingError { model_name: self.supported_models.unwrap().to_string(), message: "Preprocessor not initialized".to_string(), - }) + }), } } - pub fn postprocess_text_embeddings(&self, embeddings: Array, attention_mask: Array) -> Result, AIProxyError> { - let embeddings = match embeddings.shape().len() { - 3 => { - let existing_shape = embeddings.shape().to_vec(); - Ok(embeddings.into_dimensionality() - .map_err( - |e| AIProxyError::ModelPostprocessingError { - model_name: 
self.supported_models.unwrap().to_string(), - message: format!("Unable to convert into 3D array. Existing shape {:?}. {:?}", existing_shape, e.to_string()) - })?.to_owned()) - } - 2 => { - let existing_shape = embeddings.shape().to_vec(); - let intermediate = embeddings.into_dimensionality() - .map_err( - |e| AIProxyError::ModelPostprocessingError { - model_name: self.supported_models.unwrap().to_string(), - message: format!("Unable to convert into 2D. Existing shape {:?}. {:?}", existing_shape, e.to_string()) - })?.to_owned(); - return Ok(intermediate) - } - _ => { - Err(AIProxyError::ModelPostprocessingError { - model_name: self.supported_models.unwrap().to_string(), - message: format!("Unsupported shape for postprocessing. Shape: {:?}", embeddings.shape()) - }) - } - }?; + pub fn postprocess_text_output( + &self, + session_output: SessionOutputs, + attention_mask: Array, + ) -> Result, AIProxyError> { match &self.postprocessor { Some(ORTPostprocessor::Text(postprocessor)) => { - let output_data = postprocessor.process(embeddings, attention_mask) - .map_err( - |e| AIProxyError::ModelProviderPostprocessingError( - format!("Postprocessing failed for {:?} with error: {}", - self.supported_models.unwrap().to_string(), e) - ))?; + let output_data = postprocessor + .process(session_output, attention_mask) + .map_err(|e| { + AIProxyError::ModelProviderPostprocessingError(format!( + "Postprocessing failed for {:?} with error: {}", + self.supported_models.unwrap().to_string(), + e + )) + })?; Ok(output_data) } _ => Err(AIProxyError::ModelPostprocessingError { model_name: self.supported_models.unwrap().to_string(), message: "Postprocessor not initialized".to_string(), - }) + }), } } - pub fn postprocess_image_inference(&self, embeddings: SessionOutputs) -> Result, AIProxyError> { + pub fn postprocess_image_output( + &self, + session_output: SessionOutputs, + ) -> Result, AIProxyError> { match &self.postprocessor { Some(ORTPostprocessor::Image(postprocessor)) => { - let output_data = postprocessor.process(embeddings) - .map_err( - |e| AIProxyError::ModelProviderPostprocessingError( - format!("Postprocessing failed for {:?} with error: {}", - self.supported_models.unwrap().to_string(), e) - ))?; + let output_data = postprocessor.process(session_output).map_err(|e| { + AIProxyError::ModelProviderPostprocessingError(format!( + "Postprocessing failed for {:?} with error: {}", + self.supported_models.unwrap().to_string(), + e + )) + })?; Ok(output_data) } _ => Err(AIProxyError::ModelPostprocessingError { model_name: self.supported_models.unwrap().to_string(), message: "Postprocessor not initialized".to_string(), - }) + }), } } - pub fn batch_inference_image(&self, inputs: Array) -> Result, AIProxyError> { + pub fn batch_inference_image( + &self, + inputs: Array, + ) -> Result, AIProxyError> { let model = match &self.model { Some(ORTModel::Image(model)) => model, - _ => return Err(AIProxyError::AIModelNotSupported { model_name: self.supported_models.unwrap().to_string() }), + _ => { + return Err(AIProxyError::AIModelNotSupported { + model_name: self.supported_models.unwrap().to_string(), + }) + } }; match &model.session { Some(session) => { + let input_param = match self.supported_models.unwrap() { + SupportedModels::Resnet50 => "input", + SupportedModels::ClipVitB32Image => "pixel_values", + _ => { + return Err(AIProxyError::AIModelNotSupported { + model_name: self.supported_models.unwrap().to_string(), + }) + } + }; + let session_inputs = ort::inputs![ - model.input_params.first().expect("Hardcoded in 
parameters") - .as_str() => inputs.view(), - ].map_err(|e| AIProxyError::ModelProviderPreprocessingError(e.to_string()))?; + input_param => inputs.view(), + ] + .map_err(|e| AIProxyError::ModelProviderPreprocessingError(e.to_string()))?; - let outputs = session.run(session_inputs) + let outputs = session + .run(session_inputs) .map_err(|e| AIProxyError::ModelProviderRunInferenceError(e.to_string()))?; - let embeddings = self.postprocess_image_inference(outputs)?; + let embeddings = self.postprocess_image_output(outputs)?; Ok(embeddings) } - None => Err(AIProxyError::AIModelNotInitialized) + None => Err(AIProxyError::AIModelNotInitialized), } } - pub fn batch_inference_text(&self, encodings: Vec) -> Result, AIProxyError> { + pub fn batch_inference_text( + &self, + encodings: Vec, + ) -> Result, AIProxyError> { let model = match &self.model { Some(ORTModel::Text(model)) => model, - _ => return Err(AIProxyError::AIModelNotSupported { model_name: self.supported_models.unwrap().to_string() }), + _ => { + return Err(AIProxyError::AIModelNotSupported { + model_name: self.supported_models.unwrap().to_string(), + }) + } }; let batch_size = encodings.len(); // Extract the encoding length and batch size @@ -287,59 +288,50 @@ impl ORTProvider { // Requires the closure to be FnMut ids_array.extend(ids.iter().map(|x| *x as i64)); mask_array.extend(mask.iter().map(|x| *x as i64)); - match token_type_ids_array { - Some(ref mut token_type_ids_array) => { - token_type_ids_array.extend(encoding.get_type_ids().iter().map(|x| *x as i64)); - } - None => {} + if let Some(ref mut token_type_ids_array) = token_type_ids_array { + token_type_ids_array.extend(encoding.get_type_ids().iter().map(|x| *x as i64)); } }); // Create CowArrays from vectors let inputs_ids_array = - Array::from_shape_vec((batch_size, encoding_length), ids_array) - .map_err(|e| { - AIProxyError::ModelProviderPreprocessingError(e.to_string()) - })?; + Array::from_shape_vec((batch_size, encoding_length), ids_array).map_err( + |e| AIProxyError::ModelProviderPreprocessingError(e.to_string()), + )?; let attention_mask_array = - Array::from_shape_vec((batch_size, encoding_length), mask_array).map_err(|e| { - AIProxyError::ModelProviderPreprocessingError(e.to_string()) - })?; + Array::from_shape_vec((batch_size, encoding_length), mask_array).map_err( + |e| AIProxyError::ModelProviderPreprocessingError(e.to_string()), + )?; let token_type_ids_array = match token_type_ids_array { - Some(token_type_ids_array) => { - Some(Array::from_shape_vec((batch_size, encoding_length), token_type_ids_array) + Some(token_type_ids_array) => Some( + Array::from_shape_vec((batch_size, encoding_length), token_type_ids_array) .map_err(|e| { AIProxyError::ModelProviderPreprocessingError(e.to_string()) - })?) - }, + })?, + ), None => None, }; let mut session_inputs = ort::inputs![ "input_ids" => Value::from_array(inputs_ids_array)?, "attention_mask" => Value::from_array(attention_mask_array.view())? 
- ].map_err(|e| AIProxyError::ModelProviderPreprocessingError(e.to_string()))?; - match token_type_ids_array { - Some(token_type_ids_array) => { - session_inputs.push(( - "token_type_ids".into(), - Value::from_array(token_type_ids_array)?.into(), - )); - } - None => {} + ] + .map_err(|e| AIProxyError::ModelProviderPreprocessingError(e.to_string()))?; + + if let Some(token_type_ids_array) = token_type_ids_array { + session_inputs.push(( + "token_type_ids".into(), + Value::from_array(token_type_ids_array)?.into(), + )); } - let output_key = session.outputs.first().expect("Must exist").name.clone(); - let session_outputs = session.run(session_inputs) + let session_outputs = session + .run(session_inputs) .map_err(|e| AIProxyError::ModelProviderRunInferenceError(e.to_string()))?; - let session_output = session_outputs[output_key.as_str()] - .try_extract_tensor::() - .map_err(|e| AIProxyError::ModelProviderPostprocessingError(e.to_string()))?; - let session_output = session_output - .to_owned(); - let embeddings = self.postprocess_text_embeddings(session_output, attention_mask_array)?; + let embeddings = + self.postprocess_text_output(session_outputs, attention_mask_array)?; Ok(embeddings.to_owned()) } None => Err(AIProxyError::AIModelNotInitialized), @@ -347,7 +339,6 @@ impl ORTProvider { } } - impl ProviderTrait for ORTProvider { fn set_cache_location(&mut self, location: &Path) { self.cache_location = Some(location.join(self.cache_location_extension.clone())); @@ -380,9 +371,6 @@ impl ProviderTrait for ORTProvider { ORTModel::Image(ORTImageModel { weights_file, repo_name, - input_params: input_param, - output_param, - preprocessor_files, .. }) => { let model_repo = api.model(repo_name.clone()); @@ -395,22 +383,17 @@ impl ProviderTrait for ORTProvider { self.model = Some(ORTModel::Image(ORTImageModel { repo_name, weights_file, - input_params: input_param, - output_param, - session: Some(session), - preprocessor_files: preprocessor_files.clone(), - ..Default::default() + session: Some(session) })); - let mut preprocessor = ORTImagePreprocessor::default(); - preprocessor.load(model_repo, preprocessor_files)?; + let preprocessor = + ORTImagePreprocessor::load(self.supported_models.unwrap(), model_repo)?; self.preprocessor = Some(ORTPreprocessor::Image(preprocessor)); let postprocessor = ORTImagePostprocessor::load(supported_model)?; self.postprocessor = Some(ORTPostprocessor::Image(postprocessor)); - }, + } ORTModel::Text(ORTTextModel { weights_file, repo_name, - preprocessor_files, .. }) => { let model_repo = api.model(repo_name.clone()); @@ -424,9 +407,9 @@ impl ProviderTrait for ORTProvider { repo_name, weights_file, session: Some(session), - preprocessor_files: preprocessor_files.clone(), })); - let preprocessor = ORTTextPreprocessor::load(model_repo, preprocessor_files)?; + let preprocessor = + ORTTextPreprocessor::load(self.supported_models.unwrap(), model_repo)?; self.preprocessor = Some(ORTPreprocessor::Text(preprocessor)); let postprocessor = ORTTextPostprocessor::load(supported_model)?; self.postprocessor = Some(ORTPostprocessor::Text(postprocessor)); @@ -450,31 +433,23 @@ impl ProviderTrait for ORTProvider { .build() .map_err(|e| AIProxyError::APIBuilderError(e.to_string()))?; - match ort_model { + let (repo_name, weights_file) = match ort_model { ORTModel::Image(ORTImageModel { repo_name, weights_file, .. 
- }) => { - let model_repo = api.model(repo_name); - model_repo - .get(&weights_file) - .map_err(|e| AIProxyError::APIBuilderError(e.to_string()))?; - model_repo - .get("preprocessor_config.json") - .map_err(|e| AIProxyError::APIBuilderError(e.to_string()))?; - Ok(()) - }, + }) => (repo_name, weights_file), ORTModel::Text(ORTTextModel { repo_name, - preprocessor_files, + weights_file, .. - }) => { - let model_repo = api.model(repo_name.clone()); - Tokenize::download_artifacts(preprocessor_files.tokenize, model_repo)?; - Ok(()) - } - } + }) => (repo_name, weights_file), + }; + let model_repo = api.model(repo_name); + model_repo + .get(&weights_file) + .map_err(|e| AIProxyError::APIBuilderError(e.to_string()))?; + Ok(()) } fn run_inference( @@ -482,40 +457,36 @@ impl ProviderTrait for ORTProvider { input: ModelInput, _action_type: &InputAction, ) -> Result, AIProxyError> { - match input { ModelInput::Images(images) => { let mut store_keys: Vec = FallibleVec::try_with_capacity(images.len())?; - for batch_image in images.axis_chunks_iter(Axis(0), 16).into_iter() { + for batch_image in images.axis_chunks_iter(Axis(0), 16) { let embeddings = self.batch_inference_image(batch_image.to_owned())?; let new_store_keys: Vec = embeddings .axis_iter(Axis(0)) .into_par_iter() - .map(|embedding| StoreKey(>::from(embedding.to_owned())) - ) + .map(|embedding| StoreKey(>::from(embedding.to_owned()))) .collect(); store_keys.extend(new_store_keys); } Ok(store_keys) - }, + } ModelInput::Texts(encodings) => { - let mut store_keys: Vec = FallibleVec::try_with_capacity( - encodings.len() - )?; + let mut store_keys: Vec = + FallibleVec::try_with_capacity(encodings.len())?; for batch_encoding in encodings.into_iter().chunks(16).into_iter() { let embeddings = self.batch_inference_text(batch_encoding.collect())?; let new_store_keys: Vec = embeddings .axis_iter(Axis(0)) .into_par_iter() - .map(|embedding| StoreKey(>::from(embedding.to_owned())) - ) + .map(|embedding| StoreKey(>::from(embedding.to_owned()))) .collect(); store_keys.extend(new_store_keys); } Ok(store_keys) - }, + } } } } diff --git a/ahnlich/ai/src/engine/ai/providers/ort_helper.rs b/ahnlich/ai/src/engine/ai/providers/ort_helper.rs index b9cdaba9..d41d94ea 100644 --- a/ahnlich/ai/src/engine/ai/providers/ort_helper.rs +++ b/ahnlich/ai/src/engine/ai/providers/ort_helper.rs @@ -5,7 +5,6 @@ use std::collections::HashMap; use std::fs::File; use std::io::Read; use std::path::PathBuf; -use rayon::prelude::*; /// Public function to read a file to bytes. /// To be used when loading local model files. @@ -13,19 +12,23 @@ pub fn read_file_to_bytes(file: &PathBuf) -> Result, AIProxyError> { let mut file = File::open(file).map_err(|_| AIProxyError::ModelConfigLoadError { message: format!("failed to open file {:?}", file), })?; - let file_size = file.metadata().map_err(|_| AIProxyError::ModelConfigLoadError { - message: format!("failed to get metadata for file {:?}", file), - })?.len() as usize; + let file_size = file + .metadata() + .map_err(|_| AIProxyError::ModelConfigLoadError { + message: format!("failed to get metadata for file {:?}", file), + })? 
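// A minimal sketch of the fixed-size batching used by run_inference above:
// inputs are embedded in chunks of 16 so a large request never produces one
// giant session run. embed_batch is a hypothetical stand-in for
// batch_inference_text / batch_inference_image.
fn embed_all(inputs: &[String]) -> Vec<Vec<f32>> {
    const BATCH_SIZE: usize = 16;
    let mut out = Vec::with_capacity(inputs.len());
    for chunk in inputs.chunks(BATCH_SIZE) {
        // Each chunk becomes one ONNX run; outputs keep the input order.
        out.extend(embed_batch(chunk));
    }
    out
}

fn embed_batch(chunk: &[String]) -> Vec<Vec<f32>> {
    // Placeholder only: the real code tokenizes and runs the ORT session here.
    chunk.iter().map(|_| vec![0.0_f32; 512]).collect()
}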
+ .len() as usize; let mut buffer = Vec::with_capacity(file_size); - file.read_to_end(&mut buffer).map_err(|_| AIProxyError::ModelConfigLoadError { - message: format!("failed to read file {:?}", file), - })?; + file.read_to_end(&mut buffer) + .map_err(|_| AIProxyError::ModelConfigLoadError { + message: format!("failed to read file {:?}", file), + })?; Ok(buffer) } pub struct HFConfigReader { model_repo: ApiRepo, - cache: HashMap> + cache: HashMap>, } impl HFConfigReader { @@ -40,25 +43,22 @@ impl HFConfigReader { if let Some(value) = self.cache.get(config_name) { return value.clone(); } - let file = self.model_repo.get(config_name).map_err(|e| AIProxyError::ModelConfigLoadError{ - message: format!("failed to fetch {}, {}", config_name, e.to_string()), - })?; - let contents = read_file_to_bytes(&file).map_err(|e| AIProxyError::ModelConfigLoadError{ - message: format!("failed to read {}, {}", config_name, e.to_string()), - })?; - let value: serde_json::Value = serde_json::from_slice(&contents).map_err( - |e| AIProxyError::ModelConfigLoadError{ - message: format!("failed to parse {}, {}", config_name, e.to_string()), - })?; - self.cache.insert(config_name.to_string(), Ok(value.clone())); + let file = + self.model_repo + .get(config_name) + .map_err(|e| AIProxyError::ModelConfigLoadError { + message: format!("failed to fetch {}, {}", config_name, e), + })?; + let contents = + read_file_to_bytes(&file).map_err(|e| AIProxyError::ModelConfigLoadError { + message: format!("failed to read {}, {}", config_name, e), + })?; + let value: serde_json::Value = + serde_json::from_slice(&contents).map_err(|e| AIProxyError::ModelConfigLoadError { + message: format!("failed to parse {}, {}", config_name, e), + })?; + self.cache + .insert(config_name.to_string(), Ok(value.clone())); Ok(value) } } - -pub fn normalize(v: &[f32]) -> Vec { - let norm = (v.par_iter().map(|val| val * val).sum::()).sqrt(); - let epsilon = 1e-12; - - // We add the super-small epsilon to avoid dividing by zero - v.par_iter().map(|&val| val / (norm + epsilon)).collect() -} \ No newline at end of file diff --git a/ahnlich/ai/src/engine/ai/providers/processors/center_crop.rs b/ahnlich/ai/src/engine/ai/providers/processors/center_crop.rs index 4ffc6a15..ee2e331d 100644 --- a/ahnlich/ai/src/engine/ai/providers/processors/center_crop.rs +++ b/ahnlich/ai/src/engine/ai/providers/processors/center_crop.rs @@ -1,27 +1,22 @@ -use rayon::iter::{IntoParallelRefIterator, ParallelIterator}; use crate::engine::ai::models::ImageArray; -use crate::engine::ai::providers::processors::{CONV_NEXT_FEATURE_EXTRACTOR_CENTER_CROP_THRESHOLD, Preprocessor, PreprocessorData}; +use crate::engine::ai::providers::processors::{ + Preprocessor, PreprocessorData, CONV_NEXT_FEATURE_EXTRACTOR_CENTER_CROP_THRESHOLD, +}; use crate::error::AIProxyError; +use rayon::iter::{IntoParallelRefIterator, ParallelIterator}; pub struct CenterCrop { crop_size: (u32, u32), // (width, height) - process: bool } -impl TryFrom<&serde_json::Value> for CenterCrop { - type Error = AIProxyError; - - fn try_from(config: &serde_json::Value) -> Result { +impl CenterCrop { + pub fn initialize(config: &serde_json::Value) -> Result, AIProxyError> { if !config["do_center_crop"].as_bool().unwrap_or(false) { - return Ok( - Self { - crop_size: (0, 0), - process: false - } - ); + return Ok(None); } - let image_processor_type = config["image_processor_type"].as_str() + let image_processor_type = config["image_processor_type"] + .as_str() .unwrap_or("CLIPImageProcessor"); match image_processor_type { @@ 
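// A minimal sketch of the memoization in HFConfigReader above: each config file
// is parsed once and served from a HashMap afterwards. The fetch step is faked
// here; the real reader downloads the file from the Hugging Face model repo.
use std::collections::HashMap;

struct ConfigCache {
    cache: HashMap<String, serde_json::Value>,
}

impl ConfigCache {
    fn read(&mut self, name: &str) -> serde_json::Value {
        if let Some(cached) = self.cache.get(name) {
            return cached.clone();
        }
        // Hypothetical stand-in for "fetch from repo, read bytes, parse JSON".
        let value = serde_json::json!({ "source": name });
        self.cache.insert(name.to_string(), value.clone());
        value
    }
}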
-37,25 +32,35 @@ impl TryFrom<&serde_json::Value> for CenterCrop { } let (width, height); if crop_size.is_object() { - height = crop_size["height"].as_u64().ok_or_else(|| AIProxyError::ModelConfigLoadError { - message: "The key 'height' is missing from the ['crop_size'] section of \ - the configuration or has the wrong type; it should be an integer".to_string(), + height = crop_size["height"].as_u64().ok_or_else(|| { + AIProxyError::ModelConfigLoadError { + message: + "The key 'height' is missing from the ['crop_size'] section of \ + the configuration or has the wrong type; it should be an integer" + .to_string(), + } })? as u32; - width = crop_size["width"].as_u64().ok_or_else(|| AIProxyError::ModelConfigLoadError { - message: "The key 'width' is missing from the ['crop_size'] section of \ - the configuration or has the wrong type; it should be an integer".to_string(), + width = crop_size["width"].as_u64().ok_or_else(|| { + AIProxyError::ModelConfigLoadError { + message: + "The key 'width' is missing from the ['crop_size'] section of \ + the configuration or has the wrong type; it should be an integer" + .to_string(), + } })? as u32; } else { - let size = crop_size.as_u64().expect("It will always be an integer here.") as u32; + let size = crop_size + .as_u64() + .expect("It will always be an integer here.") + as u32; width = size; height = size; } - Ok(Self { + Ok(Some(Self { crop_size: (width, height), - process: true - }) - }, + })) + } "ConvNextFeatureExtractor" => { let size = &config["size"]; if !size.is_object() { @@ -63,20 +68,31 @@ impl TryFrom<&serde_json::Value> for CenterCrop { message: "The key 'size' is missing from the configuration or has the wrong type; it should be an object containing a 'shortest_edge' mapping.".to_string(), }); } - let shortest_edge = size["shortest_edge"].as_u64() - .ok_or_else(|| AIProxyError::ModelConfigLoadError { + let shortest_edge = size["shortest_edge"].as_u64().ok_or_else(|| { + AIProxyError::ModelConfigLoadError { message: "The key 'shortest_edge' is missing from the ['size'] section of \ - the configuration or has the wrong type; it should be an integer".to_string(), - })? as u32; - Ok(Self { - crop_size: (shortest_edge, shortest_edge), - process: shortest_edge < CONV_NEXT_FEATURE_EXTRACTOR_CENTER_CROP_THRESHOLD - }) - }, + the configuration or has the wrong type; it should be an integer" + .to_string(), + } + })? 
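// A minimal sketch of the two crop_size shapes handled above: Hugging Face
// preprocessor configs store either a bare integer (square crop) or an object
// with "width"/"height" keys. Key names are taken from the diff; the Option
// return stands in for the error handling.
fn parse_crop_size(crop_size: &serde_json::Value) -> Option<(u32, u32)> {
    if crop_size.is_object() {
        let width = crop_size["width"].as_u64()? as u32;
        let height = crop_size["height"].as_u64()? as u32;
        Some((width, height))
    } else {
        // A bare integer means the crop window is square.
        let size = crop_size.as_u64()? as u32;
        Some((size, size))
    }
}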
as u32; + + let should_center_crop = + shortest_edge < CONV_NEXT_FEATURE_EXTRACTOR_CENTER_CROP_THRESHOLD; + match should_center_crop { + true => Ok(Some(Self { + crop_size: (shortest_edge, shortest_edge), + })), + false => Ok(None), + } + } _ => Err(AIProxyError::ModelConfigLoadError { - message: format!("The key 'image_processor_type' in the configuration has the wrong value: {}; \ - it should be either 'CLIPImageProcessor' or 'ConvNextFeatureExtractor'.", image_processor_type).to_string(), - }) + message: format!( + "The key 'image_processor_type' in the configuration has the wrong value: {}; \ + it should be either 'CLIPImageProcessor' or 'ConvNextFeatureExtractor'.", + image_processor_type + ) + .to_string(), + }), } } } @@ -85,38 +101,37 @@ impl Preprocessor for CenterCrop { fn process(&self, data: PreprocessorData) -> Result { match data { PreprocessorData::ImageArray(image_array) => { - let processed = image_array.par_iter().map(|image| { - if !self.process { - return Ok(image.clone()); - } - - let (width, height) = image.image_dim(); - let width = width.get() as u32; - let height = height.get() as u32; - let (crop_width, crop_height) = self.crop_size; - if crop_width == width && crop_height == height { - let image = image.to_owned(); - Ok(image) - } else if crop_width <= width || crop_height <= height { - let x = (width - crop_width) / 2; - let y = (height - crop_height) / 2; - let image = image.crop(x, y, crop_width, crop_height)?; - Ok(image) - } else { - // The Fastembed-rs implementation pads the image with zeros, but that does not make - // sense to me (HAKSOAT), just as it does not make sense to "crop" to a bigger size. - // This is why I am going with resize, it is also important to note that - // I expect these cases to be minor because Resize will often be called before Center Crop anyway. - let image = image.resize(crop_width, crop_height, None)?; - Ok(image) - } - }) - .collect::, AIProxyError>>(); + let processed = image_array + .par_iter() + .map(|image| { + let (width, height) = image.image_dim(); + let width = width.get() as u32; + let height = height.get() as u32; + let (crop_width, crop_height) = self.crop_size; + if crop_width == width && crop_height == height { + let image = image.to_owned(); + Ok(image) + } else if crop_width <= width || crop_height <= height { + let x = (width - crop_width) / 2; + let y = (height - crop_height) / 2; + let image = image.crop(x, y, crop_width, crop_height)?; + Ok(image) + } else { + // The Fastembed-rs implementation pads the image with zeros, but that does not make + // sense to me (HAKSOAT), just as it does not make sense to "crop" to a bigger size. + // This is why I am going with resize, it is also important to note that + // I expect these cases to be minor because Resize will often be called before Center Crop anyway. + let image = image.resize(crop_width, crop_height, None)?; + Ok(image) + } + }) + .collect::, AIProxyError>>(); Ok(PreprocessorData::ImageArray(processed?)) - }, + } _ => Err(AIProxyError::CenterCropError { - message: "CenterCrop process failed. Expected ImageArray, got NdArray3C".to_string(), - }) + message: "CenterCrop process failed. 
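// A minimal sketch of the centered-crop geometry in process above: the window
// is offset so equal margins remain on each side, and a crop larger than the
// image falls back to a resize (the zero-padding alternative is rejected in
// the comment above).
fn crop_origin(image: (u32, u32), crop: (u32, u32)) -> Option<(u32, u32)> {
    let ((w, h), (cw, ch)) = (image, crop);
    if cw <= w && ch <= h {
        // Integer division rounds the top-left corner down when margins are odd.
        Some(((w - cw) / 2, (h - ch) / 2))
    } else {
        None // the caller resizes instead of cropping
    }
}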
Expected ImageArray, got NdArray3C" + .to_string(), + }), } } -} \ No newline at end of file +} diff --git a/ahnlich/ai/src/engine/ai/providers/processors/imagearray_to_ndarray.rs b/ahnlich/ai/src/engine/ai/providers/processors/imagearray_to_ndarray.rs index f562d9fc..14bcc092 100644 --- a/ahnlich/ai/src/engine/ai/providers/processors/imagearray_to_ndarray.rs +++ b/ahnlich/ai/src/engine/ai/providers/processors/imagearray_to_ndarray.rs @@ -1,9 +1,6 @@ -use image::image_dimensions; -use ndarray::{ArrayView, Ix3}; -use std::sync::Mutex; -use rayon::iter::{IntoParallelRefMutIterator, ParallelIterator}; use crate::engine::ai::providers::processors::{Preprocessor, PreprocessorData}; use crate::error::AIProxyError; +use std::sync::Mutex; pub struct ImageArrayToNdArray; @@ -11,26 +8,32 @@ impl Preprocessor for ImageArrayToNdArray { fn process(&self, data: PreprocessorData) -> Result { match data { PreprocessorData::ImageArray(mut arrays) => { - let mut array_shapes = Mutex::new(vec![]); - let mut array_views = Mutex::new(vec![]); - arrays.par_iter_mut().for_each(|image_arr| { - image_arr.onnx_transform(); - array_shapes.lock().unwrap().push(image_arr.image_dim()); - array_views.lock().unwrap().push(image_arr.view()); - }); + let array_shapes = Mutex::new(vec![]); + // Not using par_iter_mut here because it messes up the order of the images + let array_views = arrays + .iter_mut() + .map(|image_arr| { + image_arr.onnx_transform(); + array_shapes.lock().unwrap().push(image_arr.image_dim()); + image_arr.view() + }) + .collect::>(); let array_shapes = array_shapes.into_inner().unwrap(); - let array_views = array_views.into_inner().unwrap(); - - let pixel_values_array = ndarray::stack(ndarray::Axis(0), &array_views) - .map_err(|e| AIProxyError::ImageArrayToNdArrayError { - message: format!("Images must have same dimensions, instead found: {:?}. \ - NB: Dimensions listed are not in same order as images provided.", array_shapes), + let pixel_values_array = + ndarray::stack(ndarray::Axis(0), &array_views).map_err(|_| { + AIProxyError::ImageArrayToNdArrayError { + message: format!( + "Images must have same dimensions, instead found: {:?}.", + array_shapes + ), + } })?; Ok(PreprocessorData::NdArray3C(pixel_values_array)) } _ => Err(AIProxyError::ImageArrayToNdArrayError { - message: "ImageArrayToNdArray failed. Expected ImageArray, got NdArray3C".to_string(), + message: "ImageArrayToNdArray failed. 
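// A minimal sketch of the batch-stacking step above: per-image CHW tensors are
// stacked along a new leading axis into one NCHW batch, which is exactly why
// every image must share the same dimensions by this point in the pipeline.
use ndarray::{stack, Array3, Array4, Axis, ShapeError};

fn to_batch(images: &[Array3<f32>]) -> Result<Array4<f32>, ShapeError> {
    let views: Vec<_> = images.iter().map(|img| img.view()).collect();
    // stack inserts Axis(0), turning k (C, H, W) arrays into one (k, C, H, W).
    stack(Axis(0), &views)
}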
Expected ImageArray, got NdArray3C" + .to_string(), }), } } diff --git a/ahnlich/ai/src/engine/ai/providers/processors/mod.rs b/ahnlich/ai/src/engine/ai/providers/processors/mod.rs index e333fad1..08c2ab20 100644 --- a/ahnlich/ai/src/engine/ai/providers/processors/mod.rs +++ b/ahnlich/ai/src/engine/ai/providers/processors/mod.rs @@ -4,16 +4,16 @@ use ndarray::{Array, Ix2, Ix3, Ix4}; use ort::SessionOutputs; use tokenizers::Encoding; -pub mod normalize; -pub mod resize; -pub mod imagearray_to_ndarray; pub mod center_crop; -pub mod rescale; +pub mod imagearray_to_ndarray; +pub mod normalize; +mod onnx_output_transform; +pub mod pooling; +pub mod postprocessor; pub mod preprocessor; +pub mod rescale; +pub mod resize; pub mod tokenize; -pub mod postprocessor; -pub mod pooling; -mod onnx_output_transform; pub const CONV_NEXT_FEATURE_EXTRACTOR_CENTER_CROP_THRESHOLD: u32 = 384; @@ -37,7 +37,7 @@ impl PreprocessorData { match self { PreprocessorData::NdArray3C(array) => Ok(array), _ => Err(AIProxyError::ModelProviderPreprocessingError( - "`into_ndarray3c` only works for PreprocessorData::NdArray3C".to_string() + "`into_ndarray3c` only works for PreprocessorData::NdArray3C".to_string(), )), } } @@ -46,5 +46,5 @@ impl PreprocessorData { pub enum PostprocessorData<'r, 's> { OnnxOutput(SessionOutputs<'r, 's>), NdArray2(Array), - NdArray3(Array) + NdArray3(Array), } diff --git a/ahnlich/ai/src/engine/ai/providers/processors/normalize.rs b/ahnlich/ai/src/engine/ai/providers/processors/normalize.rs index 8ef87565..543dc640 100644 --- a/ahnlich/ai/src/engine/ai/providers/processors/normalize.rs +++ b/ahnlich/ai/src/engine/ai/providers/processors/normalize.rs @@ -1,50 +1,45 @@ +use crate::engine::ai::providers::processors::{ + Postprocessor, PostprocessorData, Preprocessor, PreprocessorData, +}; use crate::error::AIProxyError; -use crate::engine::ai::providers::processors::{Postprocessor, PostprocessorData, Preprocessor, PreprocessorData}; use ndarray::{Array, Axis}; use std::ops::{Div, Sub}; pub struct ImageNormalize { mean: Vec, std: Vec, - process: bool } -impl TryFrom<&serde_json::Value> for ImageNormalize { - type Error = AIProxyError; - - fn try_from(config: &serde_json::Value) -> Result { +impl ImageNormalize { + pub fn initialize(config: &serde_json::Value) -> Result, AIProxyError> { if !config["do_normalize"].as_bool().unwrap_or(false) { - return Ok( - Self { - mean: vec![], - std: vec![], - process: false - } - ); + return Ok(None); } fn get_array(value: &serde_json::Value, key: &str) -> Result, AIProxyError> { - let field = value.get(key) + let field = value + .get(key) .ok_or_else(|| AIProxyError::ModelConfigLoadError { message: format!("The key '{}' is missing from the configuration.", key), })?; - serde_json::from_value(field.to_owned()).map_err(|_| AIProxyError::ModelConfigLoadError { - message: format!("The key '{}' in the configuration must be an array of floats.", key), + serde_json::from_value(field.to_owned()).map_err(|_| { + AIProxyError::ModelConfigLoadError { + message: format!( + "The key '{}' in the configuration must be an array of floats.", + key + ), + } }) } let mean = get_array(config, "image_mean")?; let std = get_array(config, "image_std")?; - Ok(Self { mean, std, process: true }) + Ok(Some(Self { mean, std })) } } impl Preprocessor for ImageNormalize { fn process(&self, data: PreprocessorData) -> Result { - if !self.process { - return Ok(data); - } - match data { PreprocessorData::NdArray3C(array) => { let mean = Array::from_vec(self.mean.clone()) @@ -57,16 +52,21 @@ impl 
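// A minimal sketch of the channel-wise normalization configured above: the
// image_mean / image_std vectors are reshaped to (1, C, 1, 1) so ndarray can
// broadcast (x - mean) / std across an NCHW batch, matching the broadcast flow
// in the diff.
use ndarray::{Array1, Array4, Axis};

fn normalize_nchw(batch: &Array4<f32>, mean: &[f32], std: &[f32]) -> Array4<f32> {
    let per_channel = |v: &[f32]| {
        Array1::from_vec(v.to_vec())
            .insert_axis(Axis(0)) // (1, C)
            .insert_axis(Axis(2)) // (1, C, 1)
            .insert_axis(Axis(3)) // (1, C, 1, 1)
    };
    (batch - &per_channel(mean)) / &per_channel(std)
}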
Preprocessor for ImageNormalize { let shape = array.shape().to_vec(); match shape.as_slice() { [b, c, h, w] => { - let mean_broadcast = mean.broadcast((*b, *c, *h, *w)).expect("Broadcast will always succeed."); - let std_broadcast = std.broadcast((*b, *c, *h, *w)).expect("Broadcast will always succeed."); - let array_normalized = array - .sub(mean_broadcast) - .div(std_broadcast); + let mean_broadcast = mean + .broadcast((*b, *c, *h, *w)) + .expect("Broadcast will always succeed."); + let std_broadcast = std + .broadcast((*b, *c, *h, *w)) + .expect("Broadcast will always succeed."); + let array_normalized = array.sub(mean_broadcast).div(std_broadcast); Ok(PreprocessorData::NdArray3C(array_normalized)) } _ => Err(AIProxyError::ImageNormalizationError { - message: format!("Image normalization failed due to invalid shape for image array; \ - expected 4 dimensions, got {} dimensions.", shape.len()), + message: format!( + "Image normalization failed due to invalid shape for image array; \ + expected 4 dimensions, got {} dimensions.", + shape.len() + ), }), } } @@ -90,10 +90,15 @@ impl Postprocessor for VectorNormalize { let source_shape = regularized_norm.shape(); let target_shape = array.shape(); let broadcasted_norm = regularized_norm - .broadcast(array.dim()).ok_or(AIProxyError::VectorNormalizationError { - message: format!("Could not broadcast attention mask with shape {:?} to \ - shape {:?} of the input tensor.", source_shape, target_shape), - })?.to_owned(); + .broadcast(array.dim()) + .ok_or(AIProxyError::VectorNormalizationError { + message: format!( + "Could not broadcast attention mask with shape {:?} to \ + shape {:?} of the input tensor.", + source_shape, target_shape + ), + })? + .to_owned(); Ok(PostprocessorData::NdArray2(array / broadcasted_norm)) } _ => Err(AIProxyError::VectorNormalizationError { @@ -101,4 +106,4 @@ impl Postprocessor for VectorNormalize { }), } } -} \ No newline at end of file +} diff --git a/ahnlich/ai/src/engine/ai/providers/processors/onnx_output_transform.rs b/ahnlich/ai/src/engine/ai/providers/processors/onnx_output_transform.rs index 6419bb70..954d1399 100644 --- a/ahnlich/ai/src/engine/ai/providers/processors/onnx_output_transform.rs +++ b/ahnlich/ai/src/engine/ai/providers/processors/onnx_output_transform.rs @@ -1,10 +1,9 @@ -use ndarray::{Ix2, Ix3}; use crate::engine::ai::providers::processors::{Postprocessor, PostprocessorData}; use crate::error::AIProxyError; - +use ndarray::{Ix2, Ix3}; pub struct OnnxOutputTransform { - output_key: String + output_key: String, } impl OnnxOutputTransform { @@ -17,32 +16,36 @@ impl Postprocessor for OnnxOutputTransform { fn process(&self, data: PostprocessorData) -> Result { match data { PostprocessorData::OnnxOutput(onnx_output) => { - let output = onnx_output.get(self.output_key.as_str()) - .ok_or_else(|| AIProxyError::OnnxOutputTransformError { - message: format!("Output key '{}' not found in the OnnxOutput.", self.output_key), - })?; - let output = output.try_extract_tensor::().map_err( - |_| AIProxyError::OnnxOutputTransformError { + let output = onnx_output.get(self.output_key.as_str()).ok_or_else(|| { + AIProxyError::OnnxOutputTransformError { + message: format!( + "Output key '{}' not found in the OnnxOutput.", + self.output_key + ), + } + })?; + let output = output.try_extract_tensor::().map_err(|_| { + AIProxyError::OnnxOutputTransformError { message: "Failed to extract tensor from OnnxOutput.".to_string(), } - )?; + })?; match output.ndim() { 2 => { - let output = output.into_dimensionality::().map_err( - 
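// A minimal sketch of the row-wise L2 normalization done by VectorNormalize
// above: each embedding is divided by its Euclidean norm, with a small epsilon
// folded in so zero vectors do not divide by zero.
use ndarray::{Array2, Axis};

fn l2_normalize(embeddings: &Array2<f32>) -> Array2<f32> {
    // Per-row norm: sqrt of the sum of squares along the embedding axis.
    let norms = embeddings.mapv(|v| v * v).sum_axis(Axis(1)).mapv(f32::sqrt) + 1e-12_f32;
    // Broadcast the (batch, 1) norms back over the (batch, dim) embeddings.
    embeddings / &norms.insert_axis(Axis(1))
}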
|_| AIProxyError::OnnxOutputTransformError { + let output = output.into_dimensionality::().map_err(|_| { + AIProxyError::OnnxOutputTransformError { message: "Failed to convert Dyn tensor to 2D array.".to_string(), } - )?; + })?; Ok(PostprocessorData::NdArray2(output.to_owned())) - }, + } 3 => { - let output = output.into_dimensionality::().map_err( - |_| AIProxyError::OnnxOutputTransformError { + let output = output.into_dimensionality::().map_err(|_| { + AIProxyError::OnnxOutputTransformError { message: "Failed to convert Dyn tensor to 3D array.".to_string(), } - )?; + })?; Ok(PostprocessorData::NdArray3(output.to_owned())) - }, + } _ => Err(AIProxyError::OnnxOutputTransformError { message: "Only 2D and 3D tensors are supported.".to_string(), }), @@ -53,4 +56,4 @@ impl Postprocessor for OnnxOutputTransform { }), } } -} \ No newline at end of file +} diff --git a/ahnlich/ai/src/engine/ai/providers/processors/pooling.rs b/ahnlich/ai/src/engine/ai/providers/processors/pooling.rs index 9484b7b2..ea4c8728 100644 --- a/ahnlich/ai/src/engine/ai/providers/processors/pooling.rs +++ b/ahnlich/ai/src/engine/ai/providers/processors/pooling.rs @@ -1,6 +1,6 @@ -use ndarray::{s, Array, Array2, ArrayView, Dim, Axis, Dimension, IxDynImpl, Ix2}; use crate::engine::ai::providers::processors::{Postprocessor, PostprocessorData}; use crate::error::AIProxyError; +use ndarray::{s, Array, Axis, Ix2}; pub enum Pooling { Regular(RegularPooling), @@ -16,20 +16,24 @@ impl Postprocessor for RegularPooling { let processed = array.slice(s![.., 0, ..]).to_owned(); Ok(PostprocessorData::NdArray2(processed)) } + PostprocessorData::NdArray2(array) => Ok(PostprocessorData::NdArray2(array)), _ => Err(AIProxyError::PoolingError { - message: "Expected NdArray3, got NdArray2".to_string(), + message: "Expected NdArray3, NdArray2".to_string(), }), } } } +#[derive(Default)] pub struct MeanPooling { - attention_mask: Option> + attention_mask: Option>, } impl MeanPooling { pub fn new() -> Self { - Self { attention_mask: None } + Self { + attention_mask: None, + } } pub fn set_attention_mask(&mut self, attention_mask: Option>) { @@ -46,13 +50,17 @@ impl Postprocessor for MeanPooling { let attention_mask = mask.mapv(|x| x as f32); attention_mask .insert_axis(Axis(2)) - .broadcast(array.dim()).ok_or( - AIProxyError::PoolingError { - message: format!("Could not broadcast attention mask with shape {:?} to \ - shape {:?} of the input tensor.", mask.shape(), array.shape()), - } - )?.to_owned() - }, + .broadcast(array.dim()) + .ok_or(AIProxyError::PoolingError { + message: format!( + "Could not broadcast attention mask with shape {:?} to \ + shape {:?} of the input tensor.", + mask.shape(), + array.shape() + ), + })? 
+ .to_owned() + } None => Array::ones(array.dim()), }; @@ -61,10 +69,13 @@ impl Postprocessor for MeanPooling { let attention_mask_sum = attention_mask.sum_axis(Axis(1)); let min_value = 1e-9; let attention_mask_sum = attention_mask_sum.mapv(|x| x.max(min_value)); - Ok(PostprocessorData::NdArray2(&masked_array_sum / &attention_mask_sum)) + Ok(PostprocessorData::NdArray2( + &masked_array_sum / &attention_mask_sum, + )) } + PostprocessorData::NdArray2(array) => Ok(PostprocessorData::NdArray2(array)), _ => Err(AIProxyError::PoolingError { - message: "Expected NdArray3, got NdArray2".to_string(), + message: "Expected NdArray3, NdArray2".to_string(), }), } } diff --git a/ahnlich/ai/src/engine/ai/providers/processors/postprocessor.rs b/ahnlich/ai/src/engine/ai/providers/processors/postprocessor.rs index 500fd5e0..bbe28018 100644 --- a/ahnlich/ai/src/engine/ai/providers/processors/postprocessor.rs +++ b/ahnlich/ai/src/engine/ai/providers/processors/postprocessor.rs @@ -1,12 +1,12 @@ -use std::sync::{Arc, Mutex}; -use ndarray::{Array, Ix2, Ix3}; -use ort::SessionOutputs; use crate::cli::server::SupportedModels; use crate::engine::ai::providers::processors::normalize::VectorNormalize; -use crate::engine::ai::providers::processors::pooling::{Pooling, RegularPooling, MeanPooling}; -use crate::engine::ai::providers::processors::{PostprocessorData, Postprocessor}; use crate::engine::ai::providers::processors::onnx_output_transform::OnnxOutputTransform; +use crate::engine::ai::providers::processors::pooling::{MeanPooling, Pooling, RegularPooling}; +use crate::engine::ai::providers::processors::{Postprocessor, PostprocessorData}; use crate::error::AIProxyError; +use ndarray::{Array, Ix2}; +use ort::SessionOutputs; +use std::sync::{Arc, Mutex}; pub enum ORTPostprocessor { Image(ORTImagePostprocessor), @@ -15,64 +15,79 @@ pub enum ORTPostprocessor { pub struct ORTTextPostprocessor { model: SupportedModels, + onnx_output_transform: OnnxOutputTransform, pooling: Arc>, - normalize: Option + normalize: Option, } - impl ORTTextPostprocessor { pub fn load(supported_model: SupportedModels) -> Result { + let output_transform = match supported_model { + SupportedModels::AllMiniLML6V2 + | SupportedModels::AllMiniLML12V2 + | SupportedModels::BGEBaseEnV15 + | SupportedModels::BGELargeEnV15 => { + OnnxOutputTransform::new("last_hidden_state".to_string()) + } + SupportedModels::ClipVitB32Text => OnnxOutputTransform::new("text_embeds".to_string()), + _ => Err(AIProxyError::ModelPostprocessingError { + model_name: supported_model.to_string(), + message: "Unsupported model for ORTTextPostprocessor".to_string(), + })?, + }; let ops = match supported_model { - SupportedModels::AllMiniLML6V2 | - SupportedModels::AllMiniLML12V2 => Ok(( - Pooling::Mean(MeanPooling::new()), - Some(VectorNormalize) - )), - SupportedModels::BGEBaseEnV15 | - SupportedModels::BGELargeEnV15 => Ok(( - Pooling::Regular(RegularPooling), - Some(VectorNormalize) - )), - SupportedModels::ClipVitB32Text => Ok(( - Pooling::Mean(MeanPooling::new()), - None - )), + SupportedModels::AllMiniLML6V2 | SupportedModels::AllMiniLML12V2 => { + Ok((Pooling::Mean(MeanPooling::new()), Some(VectorNormalize))) + } + SupportedModels::BGEBaseEnV15 | SupportedModels::BGELargeEnV15 => { + Ok((Pooling::Regular(RegularPooling), Some(VectorNormalize))) + } + SupportedModels::ClipVitB32Text => Ok((Pooling::Mean(MeanPooling::new()), None)), _ => Err(AIProxyError::ModelPostprocessingError { model_name: supported_model.to_string(), message: "Unsupported model for 
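// A minimal sketch of the attention-masked mean pooling above: token vectors
// are zeroed where the mask is 0, summed over the token axis, and divided by
// the per-row count of real tokens (clamped away from zero), so padding never
// contributes to the sentence embedding.
use ndarray::{Array2, Array3, Axis};

fn mean_pool(hidden: &Array3<f32>, mask: &Array2<i64>) -> Array2<f32> {
    let mask_f = mask.mapv(|m| m as f32); // (batch, tokens)
    let masked = hidden * &mask_f.clone().insert_axis(Axis(2)); // broadcast over dims
    let summed = masked.sum_axis(Axis(1)); // (batch, dim)
    let counts = mask_f.sum_axis(Axis(1)).mapv(|c| c.max(1e-9));
    summed / &counts.insert_axis(Axis(1))
}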
ORTTextPostprocessor".to_string(), - }) + }), }?; Ok(Self { model: supported_model, + onnx_output_transform: output_transform, pooling: Arc::new(Mutex::new(ops.0)), - normalize: ops.1 + normalize: ops.1, }) } - pub fn process(&self, embeddings: Array, attention_mask: Array) -> Result, AIProxyError> { - let mut pooling = self.pooling.lock().map_err(|_| AIProxyError::ModelPostprocessingError { - model_name: self.model.to_string(), - message: "Failed to acquire lock on pooling.".to_string() - })?; + pub fn process( + &self, + session_outputs: SessionOutputs, + attention_mask: Array, + ) -> Result, AIProxyError> { + let embeddings = self + .onnx_output_transform + .process(PostprocessorData::OnnxOutput(session_outputs))?; + let mut pooling = + self.pooling + .lock() + .map_err(|_| AIProxyError::ModelPostprocessingError { + model_name: self.model.to_string(), + message: "Failed to acquire lock on pooling.".to_string(), + })?; let pooled = match &mut *pooling { - Pooling::Regular(pooling) => { - pooling.process(PostprocessorData::NdArray3(embeddings))? - }, + Pooling::Regular(pooling) => pooling.process(embeddings)?, Pooling::Mean(pooling) => { pooling.set_attention_mask(Some(attention_mask)); - pooling.process(PostprocessorData::NdArray3(embeddings))? + pooling.process(embeddings)? } }; let result = match &self.normalize { Some(normalize) => normalize.process(pooled), - None => Ok(pooled) + None => Ok(pooled), }?; match result { PostprocessorData::NdArray2(array) => Ok(array), _ => Err(AIProxyError::ModelPostprocessingError { model_name: self.model.to_string(), - message: "Only returns NdArray2".to_string() - }) + message: "Only returns NdArray2".to_string(), + }), } } } @@ -80,47 +95,52 @@ impl ORTTextPostprocessor { pub struct ORTImagePostprocessor { model: SupportedModels, onnx_output_transform: OnnxOutputTransform, - normalize: Option + normalize: Option, } impl ORTImagePostprocessor { pub fn load(supported_model: SupportedModels) -> Result { let output_transform = match supported_model { - SupportedModels::Resnet50 | - SupportedModels::ClipVitB32Image => - OnnxOutputTransform::new("image_embeds".to_string()), + SupportedModels::Resnet50 | SupportedModels::ClipVitB32Image => { + OnnxOutputTransform::new("image_embeds".to_string()) + } _ => Err(AIProxyError::ModelPostprocessingError { model_name: supported_model.to_string(), - message: "Unsupported model for ORTImagePostprocessor".to_string() - })? 
+ message: "Unsupported model for ORTImagePostprocessor".to_string(), + })?, }; let normalize = match supported_model { SupportedModels::Resnet50 => Ok(Some(VectorNormalize)), SupportedModels::ClipVitB32Image => Ok(None), _ => Err(AIProxyError::ModelPostprocessingError { model_name: supported_model.to_string(), - message: "Unsupported model for ORTImagePostprocessor".to_string() - }) + message: "Unsupported model for ORTImagePostprocessor".to_string(), + }), }?; Ok(Self { model: supported_model, normalize, - onnx_output_transform: output_transform + onnx_output_transform: output_transform, }) } - pub fn process(&self, session_outputs: SessionOutputs) -> Result, AIProxyError> { - let embeddings = self.onnx_output_transform.process(PostprocessorData::OnnxOutput(session_outputs))?; + pub fn process( + &self, + session_outputs: SessionOutputs, + ) -> Result, AIProxyError> { + let embeddings = self + .onnx_output_transform + .process(PostprocessorData::OnnxOutput(session_outputs))?; let result = match &self.normalize { Some(normalize) => normalize.process(embeddings), - None => Ok(embeddings) + None => Ok(embeddings), }?; match result { PostprocessorData::NdArray2(array) => Ok(array), _ => Err(AIProxyError::ModelPostprocessingError { model_name: self.model.to_string(), - message: "Only returns NdArray2".to_string() - }) + message: "Only returns NdArray2".to_string(), + }), } } -} \ No newline at end of file +} diff --git a/ahnlich/ai/src/engine/ai/providers/processors/preprocessor.rs b/ahnlich/ai/src/engine/ai/providers/processors/preprocessor.rs index a78db697..a7bfe5d2 100644 --- a/ahnlich/ai/src/engine/ai/providers/processors/preprocessor.rs +++ b/ahnlich/ai/src/engine/ai/providers/processors/preprocessor.rs @@ -1,183 +1,169 @@ -use std::iter; -use std::sync::{Arc, Mutex}; -use hf_hub::api::sync::ApiRepo; -use ndarray::{Array, Ix4}; -use tokenizers::{Encoding, Tokenizer}; +use crate::cli::server::SupportedModels; use crate::engine::ai::models::ImageArray; use crate::engine::ai::providers::ort_helper::HFConfigReader; use crate::engine::ai::providers::processors::center_crop::CenterCrop; use crate::engine::ai::providers::processors::imagearray_to_ndarray::ImageArrayToNdArray; use crate::engine::ai::providers::processors::normalize::ImageNormalize; -use crate::engine::ai::providers::processors::{Preprocessor, PreprocessorData}; use crate::engine::ai::providers::processors::rescale::Rescale; use crate::engine::ai::providers::processors::resize::Resize; -use crate::engine::ai::providers::processors::tokenize::Tokenize; +use crate::engine::ai::providers::processors::tokenize::{Tokenize, TokenizerFiles}; +use crate::engine::ai::providers::processors::{Preprocessor, PreprocessorData}; use crate::error::AIProxyError; - -#[derive(Clone)] -pub struct ImagePreprocessorFiles { - pub resize: Option, - pub normalize: Option, - pub rescale: Option, - pub center_crop: Option, -} - -impl ImagePreprocessorFiles { - pub fn iter(&self) -> impl Iterator { - iter::empty() - .chain(self.resize.as_ref().map( - |n| ("resize", n.as_str()))) - .chain(self.normalize.as_ref().map( - |n| ("normalize", n.as_str()))) - .chain(self.rescale.as_ref().map( - |n| ("rescale", n.as_str()))) - .chain(self.center_crop.as_ref().map( - |n| ("center_crop", n.as_str()))) - } -} - -impl Default for ImagePreprocessorFiles { - fn default() -> Self { - Self { - normalize: Some("preprocessor_config.json".to_string()), - resize: Some("preprocessor_config.json".to_string()), - rescale: Some("preprocessor_config.json".to_string()), - 
center_crop: Some("preprocessor_config.json".to_string()), - } - } -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct TokenizerFiles { - pub tokenizer_file: String, - pub config_file: String, - pub special_tokens_map_file: String, - pub tokenizer_config_file: String, -} - -impl Default for TokenizerFiles { - fn default() -> Self { - Self { - tokenizer_file: "tokenizer.json".to_string(), - config_file: "config.json".to_string(), - special_tokens_map_file: "special_tokens_map.json".to_string(), - tokenizer_config_file: "tokenizer_config.json".to_string(), - } - } -} - -#[derive(Default, Clone)] -pub struct TextPreprocessorFiles { - pub tokenize: TokenizerFiles, -} +use hf_hub::api::sync::ApiRepo; +use ndarray::{Array, Ix4}; +use std::sync::{Arc, Mutex}; +use tokenizers::Encoding; pub enum ORTPreprocessor { Image(ORTImagePreprocessor), Text(ORTTextPreprocessor), } -#[derive(Default)] pub struct ORTImagePreprocessor { - imagearray_to_ndarray: Option>, - normalize: Option>, - resize: Option>, - rescale: Option>, - center_crop: Option>, + model: SupportedModels, + imagearray_to_ndarray: ImageArrayToNdArray, + normalize: Option, + resize: Option, + rescale: Option, + center_crop: Option, } impl ORTImagePreprocessor { - pub fn iter(&self) -> impl Iterator)> { - iter::empty() - .chain(self.resize.as_ref().map( - |f| ("resize", f))) - .chain(self.center_crop.as_ref().map( - |f| ("center_crop", f))) - .chain(self.imagearray_to_ndarray.as_ref().map( - |f| ("imagearray_to_ndarray", f))) - .chain(self.rescale.as_ref().map( - |f| ("rescale", f))) - .chain(self.normalize.as_ref().map( - |f| ("normalize", f))) - } - - pub fn load(&mut self, model_repo: ApiRepo, processor_files: ImagePreprocessorFiles) -> Result<(), AIProxyError> { - let mut type_and_configs: Vec<(&str, Option)> = vec![ - ("imagearray_to_ndarray", None) - ]; + pub fn load( + supported_model: SupportedModels, + model_repo: ApiRepo, + ) -> Result { + let imagearray_to_ndarray = ImageArrayToNdArray; let mut config_reader = HFConfigReader::new(model_repo); - for data in processor_files.iter() { - type_and_configs.push((data.0, Some(config_reader.read(data.1)?))); - } - for (processor_type, config) in type_and_configs { - match processor_type { - "imagearray_to_ndarray" => { - self.imagearray_to_ndarray = Some(Box::new(ImageArrayToNdArray)); - } - "resize" => { - self.resize = Some(Box::new(Resize::try_from(&config.expect("Config exists"))?)); - } - "normalize" => { - self.normalize = Some(Box::new(ImageNormalize::try_from(&config.expect("Config exists"))?)); - } - "rescale" => { - self.rescale = Some(Box::new(Rescale::try_from(&config.expect("Config exists"))?)); - } - "center_crop" => { - self.center_crop = Some(Box::new(CenterCrop::try_from(&config.expect("Config exists"))?)); - } - _ => return Err(AIProxyError::ModelProviderPreprocessingError( - format!("The {} operation not found in ImagePreprocessor.", processor_type) - )) - } - } - Ok(()) + let config = config_reader.read("preprocessor_config.json")?; + + let resize = Resize::initialize(&config)?; + let center_crop = CenterCrop::initialize(&config)?; + let rescale = Rescale::initialize(&config)?; + let normalize = ImageNormalize::initialize(&config)?; + + Ok(Self { + model: supported_model, + imagearray_to_ndarray, + normalize, + resize, + rescale, + center_crop, + }) } pub fn process(&self, data: Vec) -> Result, AIProxyError> { let mut data = PreprocessorData::ImageArray(data); - for (_, processor) in self.iter() { - data = processor.process(data)?; - } + data = match self.resize { 
+ Some(ref resize) => resize.process(data).map_err( + |e| AIProxyError::ModelPreprocessingError { + model_name: self.model.to_string(), + message: format!("Failed to process resize: {}", e), + }, + )?, + None => data, + }; + + data = match self.center_crop { + Some(ref center_crop) => center_crop.process(data).map_err( + |e| AIProxyError::ModelPreprocessingError { + model_name: self.model.to_string(), + message: format!("Failed to process center crop: {}", e), + }, + )?, + None => data, + }; + + data = self.imagearray_to_ndarray.process(data).map_err( + |e| AIProxyError::ModelPreprocessingError { + model_name: self.model.to_string(), + message: format!("Failed to process imagearray to ndarray: {}", e), + }, + )?; + + data = match self.rescale { + Some(ref rescale) => rescale.process(data).map_err( + |e| AIProxyError::ModelPreprocessingError { + model_name: self.model.to_string(), + message: format!("Failed to process rescale: {}", e), + }, + )?, + None => data, + }; + + data = match self.normalize { + Some(ref normalize) => normalize.process(data).map_err( + |e| AIProxyError::ModelPreprocessingError { + model_name: self.model.to_string(), + message: format!("Failed to process normalize: {}", e), + }, + )?, + None => data, + }; + match data { PreprocessorData::NdArray3C(array) => Ok(array), - _ => Err(AIProxyError::ModelProviderPreprocessingError( - "Expected NdArray after processing".to_string() - )) + _ => Err(AIProxyError::ModelPreprocessingError { + model_name: self.model.to_string(), + message: "Expected NdArray3C after processing".to_string(), + }), } } } pub struct ORTTextPreprocessor { - pub tokenize: Arc> + model: SupportedModels, + tokenize: Arc>, } impl ORTTextPreprocessor { - pub fn load(model_repo: ApiRepo, processor_files: TextPreprocessorFiles) -> Result { - Ok( - ORTTextPreprocessor { - tokenize: Arc::new(Mutex::new( - Tokenize::initialize(processor_files.tokenize, model_repo)?, - )), - } - ) + pub fn load( + supported_models: SupportedModels, + model_repo: ApiRepo, + ) -> Result { + let tokenizer_files = TokenizerFiles { + tokenizer_file: "tokenizer.json".to_string(), + config_file: "config.json".to_string(), + special_tokens_map_file: "special_tokens_map.json".to_string(), + tokenizer_config_file: "tokenizer_config.json".to_string(), + }; + + Ok(ORTTextPreprocessor { + model: supported_models, + tokenize: Arc::new(Mutex::new(Tokenize::initialize( + tokenizer_files, + model_repo, + )?)), + }) } - pub fn process(&self, data: Vec, truncate: bool) -> Result, AIProxyError> { + pub fn process( + &self, + data: Vec, + truncate: bool, + ) -> Result, AIProxyError> { let mut data = PreprocessorData::Text(data); let mut tokenize = self.tokenize.lock().map_err(|_| { - AIProxyError::ModelProviderPreprocessingError( - "Failed to acquire lock on tokenizer".to_string(), - ) + AIProxyError::ModelPreprocessingError { + model_name: self.model.to_string(), + message: "Failed to acquire lock on tokenize.".to_string(), + } })?; - tokenize.set_truncate(truncate); - data = tokenize.process(data)?; + let _ = tokenize.set_truncate(truncate); + data = tokenize.process(data).map_err( + |e| AIProxyError::ModelPreprocessingError { + model_name: self.model.to_string(), + message: format!("Failed to process tokenize: {}", e), + }, + )?; + match data { PreprocessorData::EncodedText(encodings) => Ok(encodings), - _ => Err(AIProxyError::ModelProviderPreprocessingError( - "Expected EncodedText after processing".to_string() - )) + _ => Err(AIProxyError::ModelPreprocessingError { + model_name: 
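// A loose sketch of the staged pipeline wired up above: geometric stages run
// on images first, the batch becomes an ndarray exactly once, then numeric
// stages run on the tensor; any stage disabled by the config is simply absent.
// Stage and the Vec<f32> payload are illustrative stand-ins for the real
// Preprocessor trait and PreprocessorData.
trait Stage {
    fn apply(&self, data: Vec<f32>) -> Vec<f32>;
}

fn run_pipeline(stages: &[Option<Box<dyn Stage>>], mut data: Vec<f32>) -> Vec<f32> {
    for stage in stages.iter().flatten() {
        // None entries (e.g. do_center_crop = false) drop out via flatten.
        data = stage.apply(data);
    }
    data
}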
self.model.to_string(), + message: "Expected EncodedText after processing".to_string(), + }), } } -} \ No newline at end of file +} diff --git a/ahnlich/ai/src/engine/ai/providers/processors/rescale.rs b/ahnlich/ai/src/engine/ai/providers/processors/rescale.rs index 48e74bf6..677a536f 100644 --- a/ahnlich/ai/src/engine/ai/providers/processors/rescale.rs +++ b/ahnlich/ai/src/engine/ai/providers/processors/rescale.rs @@ -3,43 +3,31 @@ use crate::error::AIProxyError; pub struct Rescale { scale: f32, - process: bool } -impl TryFrom<&serde_json::Value> for Rescale { - type Error = AIProxyError; - - fn try_from(config: &serde_json::Value) -> Result { +impl Rescale { + pub fn initialize(config: &serde_json::Value) -> Result, AIProxyError> { if !config["do_rescale"].as_bool().unwrap_or(true) { - return Ok( - Self { - scale: 0f32, - process: false - } - ); + return Ok(None); } - let default_scale = 1.0/255.0; + let default_scale = 1.0 / 255.0; let scale = config["rescale_factor"].as_f64().unwrap_or(default_scale) as f32; - Ok(Self { scale, process: true }) + Ok(Some(Self { scale })) } } impl Preprocessor for Rescale { fn process(&self, data: PreprocessorData) -> Result { - if !self.process { - return Ok(data); - } - match data { PreprocessorData::NdArray3C(array) => { let mut array = array; array *= self.scale; Ok(PreprocessorData::NdArray3C(array)) - }, + } _ => Err(AIProxyError::RescaleError { - message: "Rescale process failed. Expected NdArray3C, got ImageArray".to_string(), + message: "Rescale process failed. Expected NdArray3C.".to_string(), }), } } -} \ No newline at end of file +} diff --git a/ahnlich/ai/src/engine/ai/providers/processors/resize.rs b/ahnlich/ai/src/engine/ai/providers/processors/resize.rs index e46dbb53..ea7d743a 100644 --- a/ahnlich/ai/src/engine/ai/providers/processors/resize.rs +++ b/ahnlich/ai/src/engine/ai/providers/processors/resize.rs @@ -1,31 +1,24 @@ -use image::imageops::FilterType; -use rayon::iter::{IntoParallelRefMutIterator, ParallelIterator}; use crate::engine::ai::models::ImageArray; -use crate::engine::ai::providers::processors::{CONV_NEXT_FEATURE_EXTRACTOR_CENTER_CROP_THRESHOLD, Preprocessor, PreprocessorData}; +use crate::engine::ai::providers::processors::{ + Preprocessor, PreprocessorData, CONV_NEXT_FEATURE_EXTRACTOR_CENTER_CROP_THRESHOLD, +}; use crate::error::AIProxyError; +use image::imageops::FilterType; +use rayon::iter::{IntoParallelRefMutIterator, ParallelIterator}; pub struct Resize { size: (u32, u32), // (width, height) resample: FilterType, - process: bool } -impl TryFrom<&serde_json::Value> for Resize { - type Error = AIProxyError; - - fn try_from(config: &serde_json::Value) -> Result { - // let config = SafeValue::new(config.to_owned()); +impl Resize { + pub fn initialize(config: &serde_json::Value) -> Result, AIProxyError> { if !config["do_resize"].as_bool().unwrap_or(false) { - return Ok( - Self { - size: (0, 0), - resample: FilterType::CatmullRom, - process: false - } - ); + return Ok(None); } - let image_processor_type = config["image_processor_type"].as_str() + let image_processor_type = config["image_processor_type"] + .as_str() .unwrap_or("CLIPImageProcessor"); let size = &config["size"]; @@ -41,45 +34,61 @@ impl TryFrom<&serde_json::Value> for Resize { let shortest_edge = &size["shortest_edge"]; let size_width = &size["width"]; let size_height = &size["height"]; - let has_value = shortest_edge.is_u64() || - (size_width.is_u64() && size_height.is_u64() + let has_value = shortest_edge.is_u64() + || (size_width.is_u64() + && 
size_height.is_u64() && image_processor_type == "CLIPImageProcessor"); if !has_value { return Err(AIProxyError::ModelConfigLoadError { message: "The ['size'] section of the configuration must contain either a \ 'shortest_edge' mapping or 'width' and 'height' mappings (when \ 'image_processor_type' is 'CLIPImageProcessor'); they should be \ - integers.".to_string(), + integers." + .to_string(), }); } if shortest_edge.is_u64() { - width = shortest_edge.as_u64().expect("It will always be an integer here.") as u32; + width = shortest_edge + .as_u64() + .expect("It will always be an integer here.") as u32; height = width; } else { - width = size_width.as_u64().expect("It will always be an integer here.") as u32; - height = size_height.as_u64().expect("It will always be an integer here.") as u32; + width = size_width + .as_u64() + .expect("It will always be an integer here.") as u32; + height = size_height + .as_u64() + .expect("It will always be an integer here.") as u32; } match image_processor_type { - "CLIPImageProcessor" => { - Ok(Self { size: (width, height), resample: FilterType::CatmullRom, process: true }) - }, + "CLIPImageProcessor" => Ok(Some(Self { + size: (width, height), + resample: FilterType::CatmullRom, + })), "ConvNextFeatureExtractor" => { if width >= CONV_NEXT_FEATURE_EXTRACTOR_CENTER_CROP_THRESHOLD { - Ok(Self { size: (width, height), resample: FilterType::CatmullRom, process: true - }) + Ok(Some(Self { + size: (width, height), + resample: FilterType::CatmullRom, + })) } else { let default_crop_pct = 0.875; let crop_pct = config["crop_pct"].as_f64().unwrap_or(default_crop_pct) as f32; let upsampled_edge = (width as f32 / crop_pct) as u32; - Ok(Self { size: (upsampled_edge, upsampled_edge), resample: FilterType::CatmullRom, - process: true }) + Ok(Some(Self { + size: (upsampled_edge, upsampled_edge), + resample: FilterType::CatmullRom, + })) } - }, + } _ => Err(AIProxyError::ModelConfigLoadError { - message: format!("Resize init failed. image_processor_type {} not supported", image_processor_type), - }) + message: format!( + "Resize init failed. 
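// A minimal sketch of the ConvNext resize rule above: below the 384px
// center-crop threshold, the image is first resized to shortest_edge / crop_pct
// (default 0.875) and only cropped back down afterwards; at or above the
// threshold the edge is used as-is. Constants mirror the diff.
fn convnext_resize_edge(shortest_edge: u32, crop_pct: Option<f64>) -> u32 {
    const CENTER_CROP_THRESHOLD: u32 = 384;
    if shortest_edge >= CENTER_CROP_THRESHOLD {
        shortest_edge
    } else {
        let crop_pct = crop_pct.unwrap_or(0.875) as f32;
        (shortest_edge as f32 / crop_pct) as u32
    }
}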
image_processor_type {} not supported", + image_processor_type + ), + }), } } } @@ -88,12 +97,9 @@ impl Preprocessor for Resize { fn process(&self, data: PreprocessorData) -> Result { match data { PreprocessorData::ImageArray(mut arrays) => { - let processed = arrays.par_iter_mut() + let processed = arrays + .par_iter_mut() .map(|image| { - if !self.process { - return Ok(image.clone()); - } - let image = image.resize(self.size.0, self.size.1, Some(self.resample))?; Ok(image) }) @@ -105,4 +111,4 @@ impl Preprocessor for Resize { }), } } -} \ No newline at end of file +} diff --git a/ahnlich/ai/src/engine/ai/providers/processors/tokenize.rs b/ahnlich/ai/src/engine/ai/providers/processors/tokenize.rs index 0095e86e..bf1eaf10 100644 --- a/ahnlich/ai/src/engine/ai/providers/processors/tokenize.rs +++ b/ahnlich/ai/src/engine/ai/providers/processors/tokenize.rs @@ -1,17 +1,15 @@ +use crate::engine::ai::providers::ort_helper::{read_file_to_bytes, HFConfigReader}; +use crate::engine::ai::providers::processors::{Preprocessor, PreprocessorData}; +use crate::error::AIProxyError; use hf_hub::api::sync::ApiRepo; use serde_json::Value; -use tokenizers::{AddedToken, PaddingParams, PaddingStrategy, Tokenizer, TruncationParams}; use tokenizers::decoders::bpe::BPEDecoder; -use crate::engine::ai::providers::ort_helper::{HFConfigReader, read_file_to_bytes}; -use crate::engine::ai::providers::processors::preprocessor::TokenizerFiles; -use crate::engine::ai::providers::processors::{Preprocessor, PreprocessorData}; -use crate::error::AIProxyError; - +use tokenizers::{AddedToken, PaddingParams, PaddingStrategy, Tokenizer, TruncationParams}; pub struct Tokenize { tokenizer: Tokenizer, model_max_length: usize, - truncate: bool + truncate: bool, } pub struct TokenizeArtifacts { @@ -22,12 +20,21 @@ pub struct TokenizeArtifacts { } impl Tokenize { - pub fn download_artifacts(tokenizer_files: TokenizerFiles, model_repo: ApiRepo) -> Result { + pub fn download_artifacts( + tokenizer_files: TokenizerFiles, + model_repo: ApiRepo, + ) -> Result { let tokenizer_bytes = read_file_to_bytes( - &model_repo.get(&tokenizer_files.tokenizer_file) - .map_err(|e| AIProxyError::ModelConfigLoadError{ - message: format!("failed to fetch {}, {}", &tokenizer_files.tokenizer_file, e.to_string()), - })?)?; + &model_repo + .get(&tokenizer_files.tokenizer_file) + .map_err(|e| AIProxyError::ModelConfigLoadError { + message: format!( + "failed to fetch {}, {}", + &tokenizer_files.tokenizer_file, + e + ), + })?, + )?; let mut config_reader = HFConfigReader::new(model_repo); let config = config_reader.read(&tokenizer_files.config_file)?; let special_tokens_map = config_reader.read(&tokenizer_files.special_tokens_map_file)?; @@ -36,16 +43,20 @@ impl Tokenize { tokenizer_bytes, config, special_tokens_map, - tokenizer_config + tokenizer_config, }) } - pub fn initialize(tokenizer_files: TokenizerFiles, model_repo: ApiRepo) -> Result { + pub fn initialize( + tokenizer_files: TokenizerFiles, + model_repo: ApiRepo, + ) -> Result { let artifacts = Self::download_artifacts(tokenizer_files, model_repo)?; - let mut tokenizer = - Tokenizer::from_bytes(artifacts.tokenizer_bytes).map_err(|_| AIProxyError::ModelTokenizerLoadError { + let mut tokenizer = Tokenizer::from_bytes(artifacts.tokenizer_bytes).map_err(|_| { + AIProxyError::ModelTokenizerLoadError { message: "Error building Tokenizer from bytes.".to_string(), - })?; + } + })?; //For BGEBaseSmall, the model_max_length value is set to 1000000000000000019884624838656. 
Which fits in a f64 let model_max_length = artifacts.tokenizer_config["model_max_length"] @@ -101,25 +112,30 @@ impl Tokenize { let decoder = BPEDecoder::new("".to_string()); tokenizer.with_decoder(Some(decoder)); - Ok(Self { tokenizer: tokenizer.into(), model_max_length, truncate: true }) + Ok(Self { + tokenizer: tokenizer.into(), + model_max_length, + truncate: true, + }) } pub fn set_truncate(&mut self, truncate: bool) -> Result<(), AIProxyError> { - let tokenizer; - if truncate { - tokenizer = self.tokenizer.with_truncation( - Some(TruncationParams { - max_length: self.model_max_length, - ..Default::default() - })).map_err(|_| AIProxyError::ModelTokenizerLoadError { + let tokenizer = if truncate { + self.tokenizer + .with_truncation(Some(TruncationParams { + max_length: self.model_max_length, + ..Default::default() + })) + .map_err(|_| AIProxyError::ModelTokenizerLoadError { message: "Error setting truncation params.".to_string(), - })?; + })? } else { - tokenizer = self.tokenizer.with_truncation(None) - .map_err(|_| AIProxyError::ModelTokenizerLoadError { + self.tokenizer.with_truncation(None).map_err(|_| { + AIProxyError::ModelTokenizerLoadError { message: "Error removing truncation params.".to_string(), - })?; - } + } + })? + }; self.truncate = truncate; self.tokenizer = tokenizer.clone().into(); Ok(()) @@ -130,15 +146,24 @@ impl Preprocessor for Tokenize { fn process(&self, data: PreprocessorData) -> Result { match data { PreprocessorData::Text(text) => { - let tokenized = self.tokenizer.encode_batch(text.clone(), true) + let tokenized = self + .tokenizer + .encode_batch(text.clone(), true) .map_err(|_| AIProxyError::ModelTokenizationError { message: format!("Tokenize process failed. Texts: {:?}", text), })?; Ok(PreprocessorData::EncodedText(tokenized)) - }, + } _ => Err(AIProxyError::ModelTokenizationError { - message: format!("Tokenize process failed. Expected Text."), + message: "Tokenize process failed. 
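// A minimal sketch of the truncation toggle above, using the tokenizers API as
// it appears in the diff: truncation is capped at the model's max length when
// the model does its own preprocessing, and removed entirely otherwise. Errors
// are discarded here for brevity.
use tokenizers::{Tokenizer, TruncationParams};

fn set_truncate(tokenizer: &mut Tokenizer, max_length: usize, truncate: bool) {
    let params = truncate.then(|| TruncationParams {
        max_length,
        ..Default::default()
    });
    // with_truncation(None) clears any previously configured truncation.
    let _ = tokenizer.with_truncation(params);
}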
Expected Text.".to_string(), }), } } } + +pub struct TokenizerFiles { + pub tokenizer_file: String, + pub config_file: String, + pub special_tokens_map_file: String, + pub tokenizer_config_file: String, +} diff --git a/ahnlich/ai/src/error.rs b/ahnlich/ai/src/error.rs index ef3ce57e..73dd827a 100644 --- a/ahnlich/ai/src/error.rs +++ b/ahnlich/ai/src/error.rs @@ -38,21 +38,13 @@ pub enum AIProxyError { }, #[error("Model preprocessing for {model_name} failed: {message}.")] - ModelPreprocessingError { - model_name: String, - message: String - }, + ModelPreprocessingError { model_name: String, message: String }, #[error("Model postprocessing for {model_name} failed: {message}.")] - ModelPostprocessingError { - model_name: String, - message: String - }, + ModelPostprocessingError { model_name: String, message: String }, #[error("Pooling operation failed: {message}.")] - PoolingError { - message: String - }, + PoolingError { message: String }, #[error( "Image Dimensions [({0}, {1})] does not match the expected model dimensions [({2}, {3})]", @@ -83,45 +75,31 @@ pub enum AIProxyError { CacheLocationNotInitiailized, #[error("index_model or query_model [{model_name}] not supported")] - AIModelNotSupported { - model_name: String - }, + AIModelNotSupported { model_name: String }, #[error("Invalid operation [{operation}] on model [{model_name}]")] AIModelInvalidOperation { operation: String, - model_name: String + model_name: String, }, #[error("Vector normalization error: [{message}]")] - VectorNormalizationError { - message: String - }, + VectorNormalizationError { message: String }, #[error("Image normalization error: [{message}]")] - ImageNormalizationError { - message: String - }, + ImageNormalizationError { message: String }, #[error("ImageArray to NdArray conversion error: [{message}]")] - ImageArrayToNdArrayError { - message: String - }, + ImageArrayToNdArrayError { message: String }, #[error("Onnx output transform error: [{message}]")] - OnnxOutputTransformError { - message: String - }, + OnnxOutputTransformError { message: String }, #[error("Rescale error: [{message}]")] - RescaleError { - message: String - }, + RescaleError { message: String }, #[error("Center crop error: [{message}]")] - CenterCropError { - message: String - }, + CenterCropError { message: String }, // TODO: Add SendError from mpsc::Sender into this variant #[error("Error sending request to model thread")] @@ -168,22 +146,16 @@ pub enum AIProxyError { ModelProviderPostprocessingError(String), #[error("Tokenize error: {message}")] - ModelTokenizationError { - message: String - }, + ModelTokenizationError { message: String }, #[error("Cannot call DelKey on store with `store_original` as false")] DelKeyError, #[error("Tokenizer for model failed to load: {message}")] - ModelTokenizerLoadError { - message: String - }, + ModelTokenizerLoadError { message: String }, #[error("Unable to load config: [{message}].")] - ModelConfigLoadError{ - message: String - } + ModelConfigLoadError { message: String }, } impl From for AIProxyError { diff --git a/ahnlich/ai/src/manager/mod.rs b/ahnlich/ai/src/manager/mod.rs index e99e31c6..6ac17be0 100644 --- a/ahnlich/ai/src/manager/mod.rs +++ b/ahnlich/ai/src/manager/mod.rs @@ -7,22 +7,20 @@ use crate::engine::ai::models::{ImageArray, InputAction}; /// lets AIProxyTasks communicate with any model to receive immediate responses via a oneshot /// channel use crate::engine::ai::models::{Model, ModelInput}; -use crate::engine::ai::providers::ModelProviders; -use 
crate::engine::ai::providers::processors::{Preprocessor, PreprocessorData}; use crate::engine::ai::providers::processors::imagearray_to_ndarray::ImageArrayToNdArray; - -use crate::engine::ai::providers::ort::ORTProvider; +use crate::engine::ai::providers::processors::{Preprocessor, PreprocessorData}; +use crate::engine::ai::providers::ModelProviders; use crate::error::AIProxyError; -use ahnlich_types::ai::{AIModel, AIStoreInputType, PreprocessAction}; +use ahnlich_types::ai::{AIModel, PreprocessAction}; use ahnlich_types::keyval::{StoreInput, StoreKey}; use fallible_collections::FallibleVec; use moka::future::Cache; use ndarray::{Array, Ix4}; use rayon::prelude::*; -use tokenizers::Encoding; use task_manager::Task; use task_manager::TaskManager; use task_manager::TaskState; +use tokenizers::Encoding; use tokio::sync::Mutex; use tokio::sync::{mpsc, oneshot}; use tokio::time::Duration; @@ -84,24 +82,34 @@ impl ModelThread { process_action: PreprocessAction, inputs: Vec, ) -> Result { - let sample = inputs.first().ok_or(AIProxyError::ModelPreprocessingError { - model_name: self.model.model_name(), - message: "Input is empty".to_string(), - })?; + let sample = inputs + .first() + .ok_or(AIProxyError::ModelPreprocessingError { + model_name: self.model.model_name(), + message: "Input is empty".to_string(), + })?; match sample { StoreInput::RawString(_) => { - let inputs: Vec = inputs.into_par_iter().filter_map(|input| match input { - StoreInput::RawString(string) => Some(string), - _ => None, - }).collect(); + let inputs: Vec = inputs + .into_par_iter() + .filter_map(|input| match input { + StoreInput::RawString(string) => Some(string), + _ => None, + }) + .collect(); let output = self.preprocess_raw_string(inputs, process_action)?; Ok(ModelInput::Texts(output)) } StoreInput::Image(_) => { - let inputs = inputs.into_par_iter().filter_map(|input| match input { - StoreInput::Image(image_bytes) => Some(ImageArray::try_new(image_bytes).ok()?), - _ => None, - }).collect(); + let inputs = inputs + .into_par_iter() + .filter_map(|input| match input { + StoreInput::Image(image_bytes) => { + Some(ImageArray::try_new(image_bytes).ok()?) + } + _ => None, + }) + .collect(); let output = self.preprocess_image(inputs, process_action)?; Ok(ModelInput::Images(output)) } @@ -122,15 +130,15 @@ impl ModelThread { match &self.model.provider { ModelProviders::ORT(provider) => { - let truncate = match process_action { - PreprocessAction::ModelPreprocessing => true, - _ => false - }; + let truncate = matches!(process_action, PreprocessAction::ModelPreprocessing); let outputs = provider.preprocess_texts(inputs, truncate)?; - let token_size = outputs.first().ok_or(AIProxyError::ModelPreprocessingError { - model_name: self.model.model_name(), - message: "Processed output is empty".to_string(), - })?.len(); + let token_size = outputs + .first() + .ok_or(AIProxyError::ModelPreprocessingError { + model_name: self.model.model_name(), + message: "Processed output is empty".to_string(), + })? 
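+                // the first encoding's length stands in for the whole batch,
+                // assuming the tokenizer pads or truncates every encoding in
+                // the batch to one uniform width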
+ .len(); if token_size > max_token_size { return Err(AIProxyError::TokenExceededError { max_token_size, @@ -147,27 +155,25 @@ impl ModelThread { fn preprocess_image( &self, inputs: Vec, - process_action: PreprocessAction + process_action: PreprocessAction, ) -> Result, AIProxyError> { // process image, return error if max dimensions exceeded - let (expected_width, expected_height) = self.model.expected_image_dimensions() - .ok_or(AIProxyError::ModelPreprocessingError { - model_name: self.model.model_name(), - message: "Image preprocessing is not supported.".to_string(), - })?; + let (expected_width, expected_height) = self.model.expected_image_dimensions().ok_or( + AIProxyError::ModelPreprocessingError { + model_name: self.model.model_name(), + message: "Image preprocessing is not supported.".to_string(), + }, + )?; let expected_width = usize::from(expected_width); let expected_height = usize::from(expected_height); match &self.model.provider { ModelProviders::ORT(provider) => { let outputs = match process_action { - PreprocessAction::ModelPreprocessing => { - provider.preprocess_images(inputs)? - } - PreprocessAction::NoPreprocessing => { - ImageArrayToNdArray.process(PreprocessorData::ImageArray(inputs))? - .into_ndarray3c()? - } + PreprocessAction::ModelPreprocessing => provider.preprocess_images(inputs)?, + PreprocessAction::NoPreprocessing => ImageArrayToNdArray + .process(PreprocessorData::ImageArray(inputs))? + .into_ndarray3c()?, }; let outputs_shape = outputs.shape(); let width = *outputs_shape.get(2).expect("Must exist"); @@ -175,7 +181,7 @@ impl ModelThread { if width != expected_width || height != expected_height { return Err(AIProxyError::ImageDimensionsMismatchError { image_dimensions: (width, height), - expected_dimensions: (expected_width.into(), expected_height.into()), + expected_dimensions: (expected_width, expected_height), }); } else { return Ok(outputs); diff --git a/ahnlich/ai/src/server/task.rs b/ahnlich/ai/src/server/task.rs index 8810bf41..95fb4d9e 100644 --- a/ahnlich/ai/src/server/task.rs +++ b/ahnlich/ai/src/server/task.rs @@ -1,7 +1,7 @@ use crate::engine::ai::models::Model; use ahnlich_client_rs::db::DbClient; use ahnlich_types::ai::{ - AIQuery, AIServerQuery, AIServerResponse, AIServerResult, PreprocessAction + AIQuery, AIServerQuery, AIServerResponse, AIServerResult, PreprocessAction, }; use ahnlich_types::client::ConnectedClient; use ahnlich_types::db::{ServerInfo, ServerResponse}; @@ -365,17 +365,17 @@ impl AhnlichProtocol for AIProxyTask { .collect(), )) } else { - Err(AIProxyError::UnexpectedDBResponse(format!("{:?}", res)) - .to_string()) + Err(AIProxyError::UnexpectedDBResponse(format!( + "{:?}", + res + )) + .to_string()) } } Err(err) => Err(format!("{err}")), } } - Err(err) => Err( - AIProxyError::StandardError(err.to_string()) - .to_string(), - ), + Err(err) => Err(AIProxyError::StandardError(err.to_string()).to_string()), } } AIQuery::PurgeStores => { diff --git a/ahnlich/ai/src/tests/aiproxy_test.rs b/ahnlich/ai/src/tests/aiproxy_test.rs index 7f66d8a9..3fb2a533 100644 --- a/ahnlich/ai/src/tests/aiproxy_test.rs +++ b/ahnlich/ai/src/tests/aiproxy_test.rs @@ -3,7 +3,7 @@ use ahnlich_db::server::handler::Server; use ahnlich_types::{ ai::{ AIModel, AIQuery, AIServerQuery, AIServerResponse, AIServerResult, AIStoreInfo, - PreprocessAction + PreprocessAction, }, db::StoreUpsert, keyval::{StoreInput, StoreName, StoreValue}, diff --git a/ahnlich/types/src/ai/preprocess.rs b/ahnlich/types/src/ai/preprocess.rs index 31ff8a78..f86c1232 100644 --- 
a/ahnlich/types/src/ai/preprocess.rs +++ b/ahnlich/types/src/ai/preprocess.rs @@ -4,7 +4,7 @@ use std::fmt; #[derive(Copy, Debug, Clone, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord, Hash)] pub enum PreprocessAction { NoPreprocessing, - ModelPreprocessing + ModelPreprocessing, } impl fmt::Display for PreprocessAction { From 66b968b32ecbb40e72a099842a9fdd2cdb5c656d Mon Sep 17 00:00:00 2001 From: HabeebShopeju Date: Tue, 26 Nov 2024 10:53:24 +0000 Subject: [PATCH 07/15] Ran type gen --- .../ahnlich_client_py/internals/ai_query.py | 140 ++---- .../internals/ai_response.py | 91 ++-- .../internals/bincode/__init__.py | 4 +- .../ahnlich_client_py/internals/db_query.py | 45 +- .../internals/db_response.py | 67 ++- .../internals/serde_binary/__init__.py | 2 +- .../internals/serde_types/__init__.py | 5 +- sdk/ahnlich-client-py/demo_embed.py | 446 +++++++++++------- type_specs/query/ai_query.json | 35 +- type_specs/response/ai_response.json | 3 + 10 files changed, 415 insertions(+), 423 deletions(-) diff --git a/sdk/ahnlich-client-py/ahnlich_client_py/internals/ai_query.py b/sdk/ahnlich-client-py/ahnlich_client_py/internals/ai_query.py index 37fb48e5..db9475c1 100644 --- a/sdk/ahnlich-client-py/ahnlich_client_py/internals/ai_query.py +++ b/sdk/ahnlich-client-py/ahnlich_client_py/internals/ai_query.py @@ -1,10 +1,8 @@ # pyre-strict -import typing from dataclasses import dataclass - -from ahnlich_client_py.internals import bincode +import typing from ahnlich_client_py.internals import serde_types as st - +from ahnlich_client_py.internals import bincode class AIModel: VARIANTS = [] # type: typing.Sequence[typing.Type[AIModel]] @@ -13,10 +11,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, AIModel) @staticmethod - def bincode_deserialize(input: bytes) -> "AIModel": + def bincode_deserialize(input: bytes) -> 'AIModel': v, buffer = bincode.deserialize(input, AIModel) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v @@ -55,12 +53,12 @@ class AIModel__ClipVitB32Image(AIModel): INDEX = 5 # type: int pass + @dataclass(frozen=True) class AIModel__ClipVitB32Text(AIModel): INDEX = 6 # type: int pass - AIModel.VARIANTS = [ AIModel__AllMiniLML6V2, AIModel__AllMiniLML12V2, @@ -79,10 +77,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, AIQuery) @staticmethod - def bincode_deserialize(input: bytes) -> "AIQuery": + def bincode_deserialize(input: bytes) -> 'AIQuery': v, buffer = bincode.deserialize(input, AIQuery) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v @@ -149,9 +147,7 @@ class AIQuery__DropNonLinearAlgorithmIndex(AIQuery): class AIQuery__Set(AIQuery): INDEX = 7 # type: int store: str - inputs: typing.Sequence[ - typing.Tuple["StoreInput", typing.Dict[str, "MetadataValue"]] - ] + inputs: typing.Sequence[typing.Tuple["StoreInput", typing.Dict[str, "MetadataValue"]]] preprocess_action: "PreprocessAction" @@ -192,7 +188,6 @@ class AIQuery__Ping(AIQuery): INDEX = 13 # type: int pass - AIQuery.VARIANTS = [ AIQuery__CreateStore, AIQuery__GetPred, @@ -220,10 +215,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, AIServerQuery) @staticmethod - def bincode_deserialize(input: bytes) -> "AIServerQuery": + def bincode_deserialize(input: bytes) -> 'AIServerQuery': v, buffer = bincode.deserialize(input, AIServerQuery) if buffer: - 
raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v @@ -234,10 +229,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, AIStoreInputType) @staticmethod - def bincode_deserialize(input: bytes) -> "AIStoreInputType": + def bincode_deserialize(input: bytes) -> 'AIStoreInputType': v, buffer = bincode.deserialize(input, AIStoreInputType) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v @@ -252,7 +247,6 @@ class AIStoreInputType__Image(AIStoreInputType): INDEX = 1 # type: int pass - AIStoreInputType.VARIANTS = [ AIStoreInputType__RawString, AIStoreInputType__Image, @@ -266,10 +260,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, Algorithm) @staticmethod - def bincode_deserialize(input: bytes) -> "Algorithm": + def bincode_deserialize(input: bytes) -> 'Algorithm': v, buffer = bincode.deserialize(input, Algorithm) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v @@ -296,7 +290,6 @@ class Algorithm__KDTree(Algorithm): INDEX = 3 # type: int pass - Algorithm.VARIANTS = [ Algorithm__EuclideanDistance, Algorithm__DotProductSimilarity, @@ -305,38 +298,6 @@ class Algorithm__KDTree(Algorithm): ] -class ImageAction: - VARIANTS = [] # type: typing.Sequence[typing.Type[ImageAction]] - - def bincode_serialize(self) -> bytes: - return bincode.serialize(self, ImageAction) - - @staticmethod - def bincode_deserialize(input: bytes) -> "ImageAction": - v, buffer = bincode.deserialize(input, ImageAction) - if buffer: - raise st.DeserializationError("Some input bytes were not read") - return v - - -@dataclass(frozen=True) -class ImageAction__ResizeImage(ImageAction): - INDEX = 0 # type: int - pass - - -@dataclass(frozen=True) -class ImageAction__ErrorIfDimensionsMismatch(ImageAction): - INDEX = 1 # type: int - pass - - -ImageAction.VARIANTS = [ - ImageAction__ResizeImage, - ImageAction__ErrorIfDimensionsMismatch, -] - - class MetadataValue: VARIANTS = [] # type: typing.Sequence[typing.Type[MetadataValue]] @@ -344,10 +305,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, MetadataValue) @staticmethod - def bincode_deserialize(input: bytes) -> "MetadataValue": + def bincode_deserialize(input: bytes) -> 'MetadataValue': v, buffer = bincode.deserialize(input, MetadataValue) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v @@ -362,7 +323,6 @@ class MetadataValue__Image(MetadataValue): INDEX = 1 # type: int value: typing.Sequence[st.uint8] - MetadataValue.VARIANTS = [ MetadataValue__RawString, MetadataValue__Image, @@ -376,10 +336,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, NonLinearAlgorithm) @staticmethod - def bincode_deserialize(input: bytes) -> "NonLinearAlgorithm": + def bincode_deserialize(input: bytes) -> 'NonLinearAlgorithm': v, buffer = bincode.deserialize(input, NonLinearAlgorithm) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v @@ -388,7 +348,6 @@ class NonLinearAlgorithm__KDTree(NonLinearAlgorithm): INDEX = 0 # type: int pass - NonLinearAlgorithm.VARIANTS = [ NonLinearAlgorithm__KDTree, ] @@ -401,10 +360,10 @@ 
def bincode_serialize(self) -> bytes: return bincode.serialize(self, Predicate) @staticmethod - def bincode_deserialize(input: bytes) -> "Predicate": + def bincode_deserialize(input: bytes) -> 'Predicate': v, buffer = bincode.deserialize(input, Predicate) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v @@ -435,7 +394,6 @@ class Predicate__NotIn(Predicate): key: str value: typing.Sequence["MetadataValue"] - Predicate.VARIANTS = [ Predicate__Equals, Predicate__NotEquals, @@ -451,10 +409,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, PredicateCondition) @staticmethod - def bincode_deserialize(input: bytes) -> "PredicateCondition": + def bincode_deserialize(input: bytes) -> 'PredicateCondition': v, buffer = bincode.deserialize(input, PredicateCondition) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v @@ -475,7 +433,6 @@ class PredicateCondition__Or(PredicateCondition): INDEX = 2 # type: int value: typing.Tuple["PredicateCondition", "PredicateCondition"] - PredicateCondition.VARIANTS = [ PredicateCondition__Value, PredicateCondition__And, @@ -490,28 +447,27 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, PreprocessAction) @staticmethod - def bincode_deserialize(input: bytes) -> "PreprocessAction": + def bincode_deserialize(input: bytes) -> 'PreprocessAction': v, buffer = bincode.deserialize(input, PreprocessAction) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v @dataclass(frozen=True) -class PreprocessAction__RawString(PreprocessAction): +class PreprocessAction__NoPreprocessing(PreprocessAction): INDEX = 0 # type: int - value: "StringAction" + pass @dataclass(frozen=True) -class PreprocessAction__Image(PreprocessAction): +class PreprocessAction__ModelPreprocessing(PreprocessAction): INDEX = 1 # type: int - value: "ImageAction" - + pass PreprocessAction.VARIANTS = [ - PreprocessAction__RawString, - PreprocessAction__Image, + PreprocessAction__NoPreprocessing, + PreprocessAction__ModelPreprocessing, ] @@ -522,10 +478,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, StoreInput) @staticmethod - def bincode_deserialize(input: bytes) -> "StoreInput": + def bincode_deserialize(input: bytes) -> 'StoreInput': v, buffer = bincode.deserialize(input, StoreInput) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v @@ -540,40 +496,8 @@ class StoreInput__Image(StoreInput): INDEX = 1 # type: int value: typing.Sequence[st.uint8] - StoreInput.VARIANTS = [ StoreInput__RawString, StoreInput__Image, ] - -class StringAction: - VARIANTS = [] # type: typing.Sequence[typing.Type[StringAction]] - - def bincode_serialize(self) -> bytes: - return bincode.serialize(self, StringAction) - - @staticmethod - def bincode_deserialize(input: bytes) -> "StringAction": - v, buffer = bincode.deserialize(input, StringAction) - if buffer: - raise st.DeserializationError("Some input bytes were not read") - return v - - -@dataclass(frozen=True) -class StringAction__TruncateIfTokensExceed(StringAction): - INDEX = 0 # type: int - pass - - -@dataclass(frozen=True) -class StringAction__ErrorIfTokensExceed(StringAction): - INDEX = 1 # type: int - pass - - 
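With `StringAction` and `ImageAction` gone, `PreprocessAction` is just two unit variants, so callers no longer wrap a nested action. A round-trip sketch against the regenerated client (illustrative, not part of the generated file):

    action = ai_query.PreprocessAction__ModelPreprocessing()
    data = action.bincode_serialize()
    assert ai_query.PreprocessAction.bincode_deserialize(data) == action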
-StringAction.VARIANTS = [ - StringAction__TruncateIfTokensExceed, - StringAction__ErrorIfTokensExceed, -] diff --git a/sdk/ahnlich-client-py/ahnlich_client_py/internals/ai_response.py b/sdk/ahnlich-client-py/ahnlich_client_py/internals/ai_response.py index 9e8cda32..6b5fc608 100644 --- a/sdk/ahnlich-client-py/ahnlich_client_py/internals/ai_response.py +++ b/sdk/ahnlich-client-py/ahnlich_client_py/internals/ai_response.py @@ -1,10 +1,8 @@ # pyre-strict -import typing from dataclasses import dataclass - -from ahnlich_client_py.internals import bincode +import typing from ahnlich_client_py.internals import serde_types as st - +from ahnlich_client_py.internals import bincode class AIModel: VARIANTS = [] # type: typing.Sequence[typing.Type[AIModel]] @@ -13,10 +11,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, AIModel) @staticmethod - def bincode_deserialize(input: bytes) -> "AIModel": + def bincode_deserialize(input: bytes) -> 'AIModel': v, buffer = bincode.deserialize(input, AIModel) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v @@ -56,6 +54,11 @@ class AIModel__ClipVitB32Image(AIModel): pass +@dataclass(frozen=True) +class AIModel__ClipVitB32Text(AIModel): + INDEX = 6 # type: int + pass + AIModel.VARIANTS = [ AIModel__AllMiniLML6V2, AIModel__AllMiniLML12V2, @@ -63,6 +66,7 @@ class AIModel__ClipVitB32Image(AIModel): AIModel__BGELargeEnV15, AIModel__Resnet50, AIModel__ClipVitB32Image, + AIModel__ClipVitB32Text, ] @@ -73,10 +77,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, AIServerResponse) @staticmethod - def bincode_deserialize(input: bytes) -> "AIServerResponse": + def bincode_deserialize(input: bytes) -> 'AIServerResponse': v, buffer = bincode.deserialize(input, AIServerResponse) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v @@ -119,21 +123,13 @@ class AIServerResponse__Set(AIServerResponse): @dataclass(frozen=True) class AIServerResponse__Get(AIServerResponse): INDEX = 6 # type: int - value: typing.Sequence[ - typing.Tuple[typing.Optional["StoreInput"], typing.Dict[str, "MetadataValue"]] - ] + value: typing.Sequence[typing.Tuple[typing.Optional["StoreInput"], typing.Dict[str, "MetadataValue"]]] @dataclass(frozen=True) class AIServerResponse__GetSimN(AIServerResponse): INDEX = 7 # type: int - value: typing.Sequence[ - typing.Tuple[ - typing.Optional["StoreInput"], - typing.Dict[str, "MetadataValue"], - "Similarity", - ] - ] + value: typing.Sequence[typing.Tuple[typing.Optional["StoreInput"], typing.Dict[str, "MetadataValue"], "Similarity"]] @dataclass(frozen=True) @@ -147,7 +143,6 @@ class AIServerResponse__CreateIndex(AIServerResponse): INDEX = 9 # type: int value: st.uint64 - AIServerResponse.VARIANTS = [ AIServerResponse__Unit, AIServerResponse__Pong, @@ -170,10 +165,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, AIServerResult) @staticmethod - def bincode_deserialize(input: bytes) -> "AIServerResult": + def bincode_deserialize(input: bytes) -> 'AIServerResult': v, buffer = bincode.deserialize(input, AIServerResult) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v @@ -188,10 +183,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, AIStoreInfo) @staticmethod - def 
bincode_deserialize(input: bytes) -> "AIStoreInfo": + def bincode_deserialize(input: bytes) -> 'AIStoreInfo': v, buffer = bincode.deserialize(input, AIStoreInfo) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v @@ -202,10 +197,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, AIStoreInputType) @staticmethod - def bincode_deserialize(input: bytes) -> "AIStoreInputType": + def bincode_deserialize(input: bytes) -> 'AIStoreInputType': v, buffer = bincode.deserialize(input, AIStoreInputType) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v @@ -220,7 +215,6 @@ class AIStoreInputType__Image(AIStoreInputType): INDEX = 1 # type: int pass - AIStoreInputType.VARIANTS = [ AIStoreInputType__RawString, AIStoreInputType__Image, @@ -236,10 +230,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, ConnectedClient) @staticmethod - def bincode_deserialize(input: bytes) -> "ConnectedClient": + def bincode_deserialize(input: bytes) -> 'ConnectedClient': v, buffer = bincode.deserialize(input, ConnectedClient) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v @@ -250,10 +244,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, MetadataValue) @staticmethod - def bincode_deserialize(input: bytes) -> "MetadataValue": + def bincode_deserialize(input: bytes) -> 'MetadataValue': v, buffer = bincode.deserialize(input, MetadataValue) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v @@ -268,7 +262,6 @@ class MetadataValue__Image(MetadataValue): INDEX = 1 # type: int value: typing.Sequence[st.uint8] - MetadataValue.VARIANTS = [ MetadataValue__RawString, MetadataValue__Image, @@ -282,10 +275,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, Result) @staticmethod - def bincode_deserialize(input: bytes) -> "Result": + def bincode_deserialize(input: bytes) -> 'Result': v, buffer = bincode.deserialize(input, Result) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v @@ -300,7 +293,6 @@ class Result__Err(Result): INDEX = 1 # type: int value: str - Result.VARIANTS = [ Result__Ok, Result__Err, @@ -319,10 +311,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, ServerInfo) @staticmethod - def bincode_deserialize(input: bytes) -> "ServerInfo": + def bincode_deserialize(input: bytes) -> 'ServerInfo': v, buffer = bincode.deserialize(input, ServerInfo) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v @@ -333,10 +325,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, ServerType) @staticmethod - def bincode_deserialize(input: bytes) -> "ServerType": + def bincode_deserialize(input: bytes) -> 'ServerType': v, buffer = bincode.deserialize(input, ServerType) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v @@ -351,7 +343,6 @@ class ServerType__AI(ServerType): INDEX = 1 # type: int pass - 
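The regenerated response module mirrors the query side: the new CLIP text variant is registered at enum index 6, which a quick check confirms (illustrative only):

    assert ai_response.AIModel__ClipVitB32Text in ai_response.AIModel.VARIANTS
    assert ai_response.AIModel__ClipVitB32Text.INDEX == 6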
ServerType.VARIANTS = [ ServerType__Database, ServerType__AI, @@ -366,10 +357,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, Similarity) @staticmethod - def bincode_deserialize(input: bytes) -> "Similarity": + def bincode_deserialize(input: bytes) -> 'Similarity': v, buffer = bincode.deserialize(input, Similarity) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v @@ -380,10 +371,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, StoreInput) @staticmethod - def bincode_deserialize(input: bytes) -> "StoreInput": + def bincode_deserialize(input: bytes) -> 'StoreInput': v, buffer = bincode.deserialize(input, StoreInput) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v @@ -398,7 +389,6 @@ class StoreInput__Image(StoreInput): INDEX = 1 # type: int value: typing.Sequence[st.uint8] - StoreInput.VARIANTS = [ StoreInput__RawString, StoreInput__Image, @@ -414,10 +404,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, StoreUpsert) @staticmethod - def bincode_deserialize(input: bytes) -> "StoreUpsert": + def bincode_deserialize(input: bytes) -> 'StoreUpsert': v, buffer = bincode.deserialize(input, StoreUpsert) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v @@ -430,10 +420,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, SystemTime) @staticmethod - def bincode_deserialize(input: bytes) -> "SystemTime": + def bincode_deserialize(input: bytes) -> 'SystemTime': v, buffer = bincode.deserialize(input, SystemTime) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v @@ -447,8 +437,9 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, Version) @staticmethod - def bincode_deserialize(input: bytes) -> "Version": + def bincode_deserialize(input: bytes) -> 'Version': v, buffer = bincode.deserialize(input, Version) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v + diff --git a/sdk/ahnlich-client-py/ahnlich_client_py/internals/bincode/__init__.py b/sdk/ahnlich-client-py/ahnlich_client_py/internals/bincode/__init__.py index 38cbd7ff..4e5e0837 100644 --- a/sdk/ahnlich-client-py/ahnlich_client_py/internals/bincode/__init__.py +++ b/sdk/ahnlich-client-py/ahnlich_client_py/internals/bincode/__init__.py @@ -1,16 +1,16 @@ # Copyright (c) Facebook, Inc. and its affiliates # SPDX-License-Identifier: MIT OR Apache-2.0 -import collections import dataclasses +import collections import io import struct import typing from copy import copy from typing import get_type_hints -from ahnlich_client_py.internals import serde_binary as sb from ahnlich_client_py.internals import serde_types as st +from ahnlich_client_py.internals import serde_binary as sb # Maximum length in practice for sequences (e.g. in Java). 
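# (that is, 2**31 - 1 == 2_147_483_647, the largest signed 32-bit integer)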
MAX_LENGTH = (1 << 31) - 1 diff --git a/sdk/ahnlich-client-py/ahnlich_client_py/internals/db_query.py b/sdk/ahnlich-client-py/ahnlich_client_py/internals/db_query.py index b281f346..b6120b2b 100644 --- a/sdk/ahnlich-client-py/ahnlich_client_py/internals/db_query.py +++ b/sdk/ahnlich-client-py/ahnlich_client_py/internals/db_query.py @@ -1,10 +1,8 @@ # pyre-strict -import typing from dataclasses import dataclass - -from ahnlich_client_py.internals import bincode +import typing from ahnlich_client_py.internals import serde_types as st - +from ahnlich_client_py.internals import bincode class Algorithm: VARIANTS = [] # type: typing.Sequence[typing.Type[Algorithm]] @@ -13,10 +11,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, Algorithm) @staticmethod - def bincode_deserialize(input: bytes) -> "Algorithm": + def bincode_deserialize(input: bytes) -> 'Algorithm': v, buffer = bincode.deserialize(input, Algorithm) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v @@ -43,7 +41,6 @@ class Algorithm__KDTree(Algorithm): INDEX = 3 # type: int pass - Algorithm.VARIANTS = [ Algorithm__EuclideanDistance, Algorithm__DotProductSimilarity, @@ -62,10 +59,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, Array) @staticmethod - def bincode_deserialize(input: bytes) -> "Array": + def bincode_deserialize(input: bytes) -> 'Array': v, buffer = bincode.deserialize(input, Array) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v @@ -76,10 +73,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, MetadataValue) @staticmethod - def bincode_deserialize(input: bytes) -> "MetadataValue": + def bincode_deserialize(input: bytes) -> 'MetadataValue': v, buffer = bincode.deserialize(input, MetadataValue) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v @@ -94,7 +91,6 @@ class MetadataValue__Image(MetadataValue): INDEX = 1 # type: int value: typing.Sequence[st.uint8] - MetadataValue.VARIANTS = [ MetadataValue__RawString, MetadataValue__Image, @@ -108,10 +104,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, NonLinearAlgorithm) @staticmethod - def bincode_deserialize(input: bytes) -> "NonLinearAlgorithm": + def bincode_deserialize(input: bytes) -> 'NonLinearAlgorithm': v, buffer = bincode.deserialize(input, NonLinearAlgorithm) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v @@ -120,7 +116,6 @@ class NonLinearAlgorithm__KDTree(NonLinearAlgorithm): INDEX = 0 # type: int pass - NonLinearAlgorithm.VARIANTS = [ NonLinearAlgorithm__KDTree, ] @@ -133,10 +128,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, Predicate) @staticmethod - def bincode_deserialize(input: bytes) -> "Predicate": + def bincode_deserialize(input: bytes) -> 'Predicate': v, buffer = bincode.deserialize(input, Predicate) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v @@ -167,7 +162,6 @@ class Predicate__NotIn(Predicate): key: str value: typing.Sequence["MetadataValue"] - Predicate.VARIANTS = [ Predicate__Equals, Predicate__NotEquals, @@ 
-183,10 +177,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, PredicateCondition) @staticmethod - def bincode_deserialize(input: bytes) -> "PredicateCondition": + def bincode_deserialize(input: bytes) -> 'PredicateCondition': v, buffer = bincode.deserialize(input, PredicateCondition) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v @@ -207,7 +201,6 @@ class PredicateCondition__Or(PredicateCondition): INDEX = 2 # type: int value: typing.Tuple["PredicateCondition", "PredicateCondition"] - PredicateCondition.VARIANTS = [ PredicateCondition__Value, PredicateCondition__And, @@ -222,10 +215,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, Query) @staticmethod - def bincode_deserialize(input: bytes) -> "Query": + def bincode_deserialize(input: bytes) -> 'Query': v, buffer = bincode.deserialize(input, Query) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v @@ -344,7 +337,6 @@ class Query__Ping(Query): INDEX = 15 # type: int pass - Query.VARIANTS = [ Query__CreateStore, Query__GetKey, @@ -374,8 +366,9 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, ServerQuery) @staticmethod - def bincode_deserialize(input: bytes) -> "ServerQuery": + def bincode_deserialize(input: bytes) -> 'ServerQuery': v, buffer = bincode.deserialize(input, ServerQuery) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v + diff --git a/sdk/ahnlich-client-py/ahnlich_client_py/internals/db_response.py b/sdk/ahnlich-client-py/ahnlich_client_py/internals/db_response.py index d1d0a6c4..acd3baa1 100644 --- a/sdk/ahnlich-client-py/ahnlich_client_py/internals/db_response.py +++ b/sdk/ahnlich-client-py/ahnlich_client_py/internals/db_response.py @@ -1,10 +1,8 @@ # pyre-strict -import typing from dataclasses import dataclass - -from ahnlich_client_py.internals import bincode +import typing from ahnlich_client_py.internals import serde_types as st - +from ahnlich_client_py.internals import bincode @dataclass(frozen=True) class Array: @@ -16,10 +14,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, Array) @staticmethod - def bincode_deserialize(input: bytes) -> "Array": + def bincode_deserialize(input: bytes) -> 'Array': v, buffer = bincode.deserialize(input, Array) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v @@ -32,10 +30,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, ConnectedClient) @staticmethod - def bincode_deserialize(input: bytes) -> "ConnectedClient": + def bincode_deserialize(input: bytes) -> 'ConnectedClient': v, buffer = bincode.deserialize(input, ConnectedClient) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v @@ -46,10 +44,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, MetadataValue) @staticmethod - def bincode_deserialize(input: bytes) -> "MetadataValue": + def bincode_deserialize(input: bytes) -> 'MetadataValue': v, buffer = bincode.deserialize(input, MetadataValue) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise 
st.DeserializationError("Some input bytes were not read"); return v @@ -64,7 +62,6 @@ class MetadataValue__Image(MetadataValue): INDEX = 1 # type: int value: typing.Sequence[st.uint8] - MetadataValue.VARIANTS = [ MetadataValue__RawString, MetadataValue__Image, @@ -78,10 +75,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, Result) @staticmethod - def bincode_deserialize(input: bytes) -> "Result": + def bincode_deserialize(input: bytes) -> 'Result': v, buffer = bincode.deserialize(input, Result) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v @@ -96,7 +93,6 @@ class Result__Err(Result): INDEX = 1 # type: int value: str - Result.VARIANTS = [ Result__Ok, Result__Err, @@ -115,10 +111,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, ServerInfo) @staticmethod - def bincode_deserialize(input: bytes) -> "ServerInfo": + def bincode_deserialize(input: bytes) -> 'ServerInfo': v, buffer = bincode.deserialize(input, ServerInfo) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v @@ -129,10 +125,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, ServerResponse) @staticmethod - def bincode_deserialize(input: bytes) -> "ServerResponse": + def bincode_deserialize(input: bytes) -> 'ServerResponse': v, buffer = bincode.deserialize(input, ServerResponse) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v @@ -181,9 +177,7 @@ class ServerResponse__Get(ServerResponse): @dataclass(frozen=True) class ServerResponse__GetSimN(ServerResponse): INDEX = 7 # type: int - value: typing.Sequence[ - typing.Tuple["Array", typing.Dict[str, "MetadataValue"], "Similarity"] - ] + value: typing.Sequence[typing.Tuple["Array", typing.Dict[str, "MetadataValue"], "Similarity"]] @dataclass(frozen=True) @@ -197,7 +191,6 @@ class ServerResponse__CreateIndex(ServerResponse): INDEX = 9 # type: int value: st.uint64 - ServerResponse.VARIANTS = [ ServerResponse__Unit, ServerResponse__Pong, @@ -220,10 +213,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, ServerResult) @staticmethod - def bincode_deserialize(input: bytes) -> "ServerResult": + def bincode_deserialize(input: bytes) -> 'ServerResult': v, buffer = bincode.deserialize(input, ServerResult) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v @@ -234,10 +227,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, ServerType) @staticmethod - def bincode_deserialize(input: bytes) -> "ServerType": + def bincode_deserialize(input: bytes) -> 'ServerType': v, buffer = bincode.deserialize(input, ServerType) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v @@ -252,7 +245,6 @@ class ServerType__AI(ServerType): INDEX = 1 # type: int pass - ServerType.VARIANTS = [ ServerType__Database, ServerType__AI, @@ -267,10 +259,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, Similarity) @staticmethod - def bincode_deserialize(input: bytes) -> "Similarity": + def bincode_deserialize(input: bytes) -> 'Similarity': v, buffer = bincode.deserialize(input, 
Similarity) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v @@ -284,10 +276,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, StoreInfo) @staticmethod - def bincode_deserialize(input: bytes) -> "StoreInfo": + def bincode_deserialize(input: bytes) -> 'StoreInfo': v, buffer = bincode.deserialize(input, StoreInfo) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v @@ -300,10 +292,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, StoreUpsert) @staticmethod - def bincode_deserialize(input: bytes) -> "StoreUpsert": + def bincode_deserialize(input: bytes) -> 'StoreUpsert': v, buffer = bincode.deserialize(input, StoreUpsert) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v @@ -316,10 +308,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, SystemTime) @staticmethod - def bincode_deserialize(input: bytes) -> "SystemTime": + def bincode_deserialize(input: bytes) -> 'SystemTime': v, buffer = bincode.deserialize(input, SystemTime) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v @@ -333,8 +325,9 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, Version) @staticmethod - def bincode_deserialize(input: bytes) -> "Version": + def bincode_deserialize(input: bytes) -> 'Version': v, buffer = bincode.deserialize(input, Version) if buffer: - raise st.DeserializationError("Some input bytes were not read") + raise st.DeserializationError("Some input bytes were not read"); return v + diff --git a/sdk/ahnlich-client-py/ahnlich_client_py/internals/serde_binary/__init__.py b/sdk/ahnlich-client-py/ahnlich_client_py/internals/serde_binary/__init__.py index a71b03f5..0730bd23 100644 --- a/sdk/ahnlich-client-py/ahnlich_client_py/internals/serde_binary/__init__.py +++ b/sdk/ahnlich-client-py/ahnlich_client_py/internals/serde_binary/__init__.py @@ -7,8 +7,8 @@ Note: This internal module is currently only meant to share code between the BCS and bincode formats. Internal APIs could change in the future. """ -import collections import dataclasses +import collections import io import typing from typing import get_type_hints diff --git a/sdk/ahnlich-client-py/ahnlich_client_py/internals/serde_types/__init__.py b/sdk/ahnlich-client-py/ahnlich_client_py/internals/serde_types/__init__.py index 1c85909c..6d72f027 100644 --- a/sdk/ahnlich-client-py/ahnlich_client_py/internals/serde_types/__init__.py +++ b/sdk/ahnlich-client-py/ahnlich_client_py/internals/serde_types/__init__.py @@ -1,10 +1,9 @@ # Copyright (c) Facebook, Inc. 
and its affiliates # SPDX-License-Identifier: MIT OR Apache-2.0 -import typing -from dataclasses import dataclass - import numpy as np +from dataclasses import dataclass +import typing class SerializationError(ValueError): diff --git a/sdk/ahnlich-client-py/demo_embed.py b/sdk/ahnlich-client-py/demo_embed.py index c9423b74..7d34e1b9 100644 --- a/sdk/ahnlich-client-py/demo_embed.py +++ b/sdk/ahnlich-client-py/demo_embed.py @@ -6,179 +6,293 @@ from ahnlich_client_py.clients import AhnlichAIClient from ahnlich_client_py.internals import ai_query -ai_store_payload_no_predicates = { - "store_name": "Diretnan Stores", - "query_model": ai_query.AIModel__AllMiniLML6V2(), - "index_model": ai_query.AIModel__AllMiniLML6V2(), -} - -ai_store_payload_with_predicates = { - "store_name": "Diretnan Predication Stores", - "query_model": ai_query.AIModel__AllMiniLML6V2(), - "index_model": ai_query.AIModel__AllMiniLML6V2(), - "predicates": ["special", "brand"], -} - -ai_store_payload_with_predicates_images = { - "store_name": "Diretnan Image Predication Stores", - "query_model": ai_query.AIModel__Resnet50(), - "index_model": ai_query.AIModel__Resnet50(), - "predicates": ["special", "brand"], -} - -ai_store_payload_with_predicates_images_texts = { - "store_name": "Diretnan Image Text Predication Stores", - "query_model": ai_query.AIModel__ClipVitB32Text(), - "index_model": ai_query.AIModel__ClipVitB32Image(), - "predicates": ["special", "brand"], -} - - -def run_insert_text(): - ai_client = AhnlichAIClient(address="127.0.0.1", port=1370, connect_timeout_sec=30) - store_inputs = [ - ( - ai_query.StoreInput__RawString("Jordan One"), - {"brand": ai_query.MetadataValue__RawString("Nike")}, - ), - ( - ai_query.StoreInput__RawString("Air Jordan"), - {"brand": ai_query.MetadataValue__RawString("Nike")}, - ), - ( - ai_query.StoreInput__RawString("Chicago Bulls"), - {"brand": ai_query.MetadataValue__RawString("NBA")}, - ), - ( - ai_query.StoreInput__RawString("Los Angeles Lakers"), - {"brand": ai_query.MetadataValue__RawString("NBA")}, - ), - ( - ai_query.StoreInput__RawString("Yeezey"), - {"brand": ai_query.MetadataValue__RawString("Adidas")}, - ), - ] - builder = ai_client.pipeline() - builder.create_store(**ai_store_payload_with_predicates) - builder.set( - store_name=ai_store_payload_with_predicates["store_name"], - inputs=store_inputs, - preprocess_action=ai_query.PreprocessAction__RawString( - ai_query.StringAction__ErrorIfTokensExceed() - ), - ) - return builder.exec() - - -def run_get_simn_text(): - ai_client = AhnlichAIClient(address="127.0.0.1", port=1370) - builder = ai_client.pipeline() - builder.get_sim_n( - store_name=ai_store_payload_with_predicates["store_name"], - search_input=ai_query.StoreInput__RawString("Basketball"), - closest_n=3, - algorithm=ai_query.Algorithm__CosineSimilarity(), - ) - return builder.exec() - - -def insert_image(urls, store_data): - ai_client = AhnlichAIClient(address="127.0.0.1", port=1370, connect_timeout_sec=30) - builder = ai_client.pipeline() - builder.create_store(**store_data) - for url, brand in urls: - print("Processing image: ", url) - if url.startswith("http"): - location = urlopen(url) - else: - location = url - image = Image.open(location) - buffer = io.BytesIO() - image.save(buffer, format=image.format) - store_inputs = [ +class Text2TextDemo: + def __init__(self): + ai_client = AhnlichAIClient( + address="127.0.0.1", port=1370, connect_timeout_sec=30 + ) + self.query_model = ai_query.AIModel__AllMiniLML6V2() + self.index_model = ai_query.AIModel__AllMiniLML6V2() 
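+        # NOTE: query_model and index_model must embed into the same vector
+        # space; this demo uses one sentence transformer for both, while the
+        # cross-modal demos below pair ClipVitB32Text with ClipVitB32Image.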
+ self.store_name = "The Sports Press Club" + self.builder = ai_client.pipeline() + predicates = ["sport"] + self.builder.create_store( + store_name=self.store_name, + query_model=self.query_model, + index_model=self.index_model, + predicates=predicates, + ) + + def insert(self): + # Initial list of tuples (snippet, sport) + snippets_and_sports = [ + ( + "Manchester City secures a thrilling 2-1 victory over Liverpool in the Premier League, " + "with Erling Haaland scoring the decisive goal in the 87th minute.", + "Football", + ), ( - ai_query.StoreInput__Image(buffer.getvalue()), - {"brand": ai_query.MetadataValue__RawString(brand)}, + "Coco Gauff clinches a hard-fought victory in a gripping three-set final against Iga Swiatek " + "to win the Wimbledon Finals, solidifying her place among the top competitors.", + "Tennis", ), + ( + "LeBron James makes history yet again, becoming the NBA's all-time leading scorer in a single " + "season as the Lakers defeat the Golden State Warriors 120-115.", + "Basketball", + ), + ( + "India edges out Australia in a nail-biting T20 match, with Virat Kohli's unbeaten 78 " + "guiding the team to a thrilling last-over victory.", + "Cricket", + ), + ( + "Max Verstappen dominates the Abu Dhabi Grand Prix, achieving an incredible 16th win " + "of the season, a milestone that underscores his unparalleled dominance and secures his third " + "consecutive championship title.", + "Formula 1", + ), + ] + + store_inputs = [ + ( + ai_query.StoreInput__RawString(snippet), + {"sport": ai_query.MetadataValue__RawString(sport)}, + ) + for snippet, sport in snippets_and_sports ] - builder.set( - store_name=store_data["store_name"], + self.builder.set( + store_name=self.store_name, inputs=store_inputs, - preprocess_action=ai_query.PreprocessAction__Image( - ai_query.ImageAction__ResizeImage() - ), + preprocess_action=ai_query.PreprocessAction__ModelPreprocessing(), + ) + return self.builder.exec() + + def query(self): + search_input = "News events where athletes broke a record" + self.builder.get_sim_n( + store_name=self.store_name, + search_input=ai_query.StoreInput__RawString(search_input), + closest_n=3, + algorithm=ai_query.Algorithm__CosineSimilarity(), + ) + return self.builder.exec() + + +class VeryShortText2TextDemo: + def __init__(self): + ai_client = AhnlichAIClient( + address="127.0.0.1", port=1370, connect_timeout_sec=30 + ) + self.query_model = ai_query.AIModel__ClipVitB32Text() + self.index_model = ai_query.AIModel__ClipVitB32Text() + self.store_name = "The Literary Collection" + self.builder = ai_client.pipeline() + predicates = ["citizenship"] + self.builder.create_store( + store_name=self.store_name, + query_model=self.query_model, + index_model=self.index_model, + predicates=predicates, + ) + + def insert(self): + # Initial list of tuples (snippet, writer's citizenship) + snippets_and_citizenship = [ + ("1984", "English"), + ("Things Fall Apart", "Nigerian"), + ("The Great Gatsby", "American"), + ("The Alchemist", "Brazilian"), + ("Man's Search for Meaning", "Austrian"), + ] + + # Create store_inputs using a list comprehension + store_inputs = [ + ( + ai_query.StoreInput__RawString(snippet), + {"citizenship": ai_query.MetadataValue__RawString(citizenship)}, + ) + for snippet, citizenship in snippets_and_citizenship + ] + + self.builder.set( + store_name=self.store_name, + inputs=store_inputs, + preprocess_action=ai_query.PreprocessAction__ModelPreprocessing(), ) + return self.builder.exec() - return builder.exec() - -def run_insert_image(): - image_urls = [ 
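A tiny driver makes the flow of the classes above concrete. This is hypothetical and not part of the patch; it assumes an ahnlich-ai instance listening on 127.0.0.1:1370 with the listed models enabled:

    if __name__ == "__main__":
        demo = Text2TextDemo()
        demo.insert()        # create the store and index the five snippets
        print(demo.query())  # three nearest snippets by cosine similarity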
- ( - "https://cdn.britannica.com/96/195196-050-3909D5BD/Michael-Jordan-1988.jpg", - "Slam Dunk Jordan", - ), - ("https://i.ebayimg.com/images/g/0-wAAOSwsQ1h5Pqc/s-l1600.webp", "Air Jordan"), - ( - "https://as2.ftcdn.net/v2/jpg/02/70/86/51/1000_F_270865104_HMpmjP3Hqt0MvdlV7QkQJful50bBzj46.jpg", - "Aeroplane", - ), - ( - "https://csaenvironmental.co.uk/wp-content/uploads/2020/06/landscape-value-600x325.jpg", - "Landscape", - ), - ] - return insert_image(image_urls, ai_store_payload_with_predicates_images) - - -def run_get_simn_image(): - ai_client = AhnlichAIClient(address="127.0.0.1", port=1370) - builder = ai_client.pipeline() - url = "https://i.pinimg.com/564x/9d/76/c8/9d76c8229b7528643d69636c1a9a428d.jpg" - image = Image.open(urlopen(url)) + def query(self): + search_input = "Chinua Achebe" + self.builder.get_sim_n( + store_name=self.store_name, + search_input=ai_query.StoreInput__RawString(search_input), + closest_n=3, + algorithm=ai_query.Algorithm__CosineSimilarity(), + ) + return self.builder.exec() + + +def url_to_buffer(url): + """ + Converts an image URL or local file path to a buffer value. + :param url: URL or file path of the image. + :return: BytesIO buffer containing the image data. + """ + print(f"Processing image: {url}") + if url.startswith("http"): + location = urlopen(url) + else: + location = url + + image = Image.open(location) buffer = io.BytesIO() image.save(buffer, format=image.format) - builder.get_sim_n( - store_name=ai_store_payload_with_predicates_images["store_name"], - search_input=ai_query.StoreInput__Image(buffer.getvalue()), - closest_n=3, - algorithm=ai_query.Algorithm__CosineSimilarity(), - ) - return builder.exec() - - -def run_insert_image_text(): - image_urls = [ - ( - "https://imageio.forbes.com/specials-images/imageserve/632357fbf1cebc1639065099/Roger-Federer-celebrated" - "-after-beating-Lorenzo-Sonego-at-Wimbledon-last-year-/1960x0.jpg?format=jpg&width=960", - "Roger Federer", - ), - ("https://www.silverarrows.net/wp-content/uploads/2020/05/Lewis-Hamilton-Japan.jpg", "Lewis Hamilton"), - ( - "https://img.20mn.fr/B2Dto_H3RveJTzabY4IR2yk/1444x920_andreja-laski-of-team-slovenia-and-clarisse-agbegnenou" - "-team-france-compete-during-the-women-63-kg-semifinal-of-table-b-contest-on-day-four-of-the-olympic-games-" - "paris-2024-at-champs-de-mars-arena-03vulaurent-d2317-credit-laurent-vu-sipa-2407301738", - "Clarisse Agbegnenou", - ), - ( - "https://c8.alamy.com/comp/R1YEE4/london-uk-15th-november-2018-jadon-sancho-of-england-is-tackled-by-" - "christian-pulisic-of-usa-during-the-international-friendly-match-between-england-and-usa-at-wembley-" - "stadium-on-november-15th-2018-in-london-england-photo-by-matt-bradshawphcimages-credit-phc-imagesalamy-live-news-R1YEE4.jpg", - "Christian Pulisic and Sancho", - ), - ] - return insert_image(image_urls, ai_store_payload_with_predicates_images_texts) - - -def run_get_simn_image_text(): - ai_client = AhnlichAIClient(address="127.0.0.1", port=1370) - builder = ai_client.pipeline() - builder.get_sim_n( - store_name=ai_store_payload_with_predicates_images_texts["store_name"], - search_input=ai_query.StoreInput__RawString("United States vs England"), - closest_n=3, - algorithm=ai_query.Algorithm__CosineSimilarity(), - ) - return builder.exec() + buffer.seek(0) # Reset the buffer pointer to the beginning + return buffer + + +class Text2ImageDemo: + def __init__(self): + ai_client = AhnlichAIClient( + address="127.0.0.1", port=1370, connect_timeout_sec=30 + ) + self.query_model = ai_query.AIModel__ClipVitB32Text() + 
self.index_model = ai_query.AIModel__ClipVitB32Image()
+        self.store_name = "The Sports Image Collection"
+        self.builder = ai_client.pipeline()
+        predicates = ["athlete"]
+        self.builder.create_store(
+            store_name=self.store_name,
+            query_model=self.query_model,
+            index_model=self.index_model,
+            predicates=predicates,
+            store_original=False,
+        )
+
+    def insert(self):
+        # Initial list of tuples (image URL, athlete name)
+        image_urls_and_athletes = [
+            (
+                "https://imageio.forbes.com/specials-images/imageserve/632357fbf1cebc1639065099/Roger-Federer-celebrated"
+                "-after-beating-Lorenzo-Sonego-at-Wimbledon-last-year-/1960x0.jpg?format=jpg&width=960",
+                "Roger Federer",
+            ),
+            (
+                "https://www.silverarrows.net/wp-content/uploads/2020/05/Lewis-Hamilton-Japan.jpg",
+                "Lewis Hamilton",
+            ),
+            (
+                "https://img.20mn.fr/B2Dto_H3RveJTzabY4IR2yk/1444x920_andreja-laski-of-team-slovenia-and-clarisse-agbegnenou"
+                "-team-france-compete-during-the-women-63-kg-semifinal-of-table-b-contest-on-day-four-of-the-olympic-games-"
+                "paris-2024-at-champs-de-mars-arena-03vulaurent-d2317-credit-laurent-vu-sipa-2407301738",
+                "Clarisse Agbegnenou",
+            ),
+            (
+                "https://c8.alamy.com/comp/R1YEE4/london-uk-15th-november-2018-jadon-sancho-of-england-is-tackled-by-"
+                "christian-pulisic-of-usa-during-the-international-friendly-match-between-england-and-usa-at-wembley-"
+                "stadium-on-november-15th-2018-in-london-england-photo-by-matt-bradshawphcimages-credit-phc-imagesalamy-live-news-R1YEE4.jpg",
+                "Christian Pulisic and Sancho",
+            ),
+        ]
+
+        # Process images and create store_inputs, keyed to match the
+        # "athlete" predicate declared on the store
+        store_inputs = [
+            (
+                ai_query.StoreInput__Image(url_to_buffer(url).getvalue()),
+                {"athlete": ai_query.MetadataValue__RawString(athlete)},
+            )
+            for url, athlete in image_urls_and_athletes
+        ]
+
+        # Set the store inputs
+        self.builder.set(
+            store_name=self.store_name,
+            inputs=store_inputs,
+            preprocess_action=ai_query.PreprocessAction__ModelPreprocessing(),
+        )
+        return self.builder.exec()
+
+    def query(self):
+        search_input = "United States vs England"
+        self.builder.get_sim_n(
+            store_name=self.store_name,
+            search_input=ai_query.StoreInput__RawString(search_input),
+            closest_n=3,
+            algorithm=ai_query.Algorithm__CosineSimilarity(),
+        )
+        return self.builder.exec()
+
+
+class Image2ImageDemo:
+    def __init__(self):
+        ai_client = AhnlichAIClient(
+            address="127.0.0.1", port=1370, connect_timeout_sec=30
+        )
+        self.query_model = ai_query.AIModel__ClipVitB32Image()
+        self.index_model = ai_query.AIModel__ClipVitB32Image()
+        self.store_name = "The Jordan or Not Jordan Collection"
+        self.builder = ai_client.pipeline()
+        predicates = ["label"]
+        self.builder.create_store(
+            store_name=self.store_name,
+            query_model=self.query_model,
+            index_model=self.index_model,
+            predicates=predicates,
+            store_original=False,
+        )
+
+    def insert(self):
+        # Initial list of tuples (image URL, image label)
+        image_urls_and_labels = [
+            (
+                "https://cdn.britannica.com/96/195196-050-3909D5BD/Michael-Jordan-1988.jpg",
+                "Slam Dunk Jordan",
+            ),
+            (
+                "https://i.ebayimg.com/images/g/0-wAAOSwsQ1h5Pqc/s-l1600.webp",
+                "Air Jordan",
+            ),
+            (
+                "https://as2.ftcdn.net/v2/jpg/02/70/86/51/1000_F_270865104_HMpmjP3Hqt0MvdlV7QkQJful50bBzj46.jpg",
+                "Aeroplane",
+            ),
+            (
+                "https://csaenvironmental.co.uk/wp-content/uploads/2020/06/landscape-value-600x325.jpg",
+                "Landscape",
+            ),
+        ]
+
+        # Process images and create store_inputs
+        store_inputs = [
+            (
+                ai_query.StoreInput__Image(url_to_buffer(url).getvalue()),
+                {"label": ai_query.MetadataValue__RawString(label)},
+            )
+            for url, 
label in image_urls_and_labels + ] + + # Set the store inputs + self.builder.set( + store_name=self.store_name, + inputs=store_inputs, + preprocess_action=ai_query.PreprocessAction__ModelPreprocessing(), + ) + return self.builder.exec() + + def query(self): + # Query with an image + query_url = ( + "https://i.pinimg.com/564x/9d/76/c8/9d76c8229b7528643d69636c1a9a428d.jpg" + ) + buffer = url_to_buffer(query_url) + + self.builder.get_sim_n( + store_name=self.store_name, + search_input=ai_query.StoreInput__Image(buffer.getvalue()), + closest_n=3, + algorithm=ai_query.Algorithm__CosineSimilarity(), + ) + return self.builder.exec() diff --git a/type_specs/query/ai_query.json b/type_specs/query/ai_query.json index d8184808..8325591b 100644 --- a/type_specs/query/ai_query.json +++ b/type_specs/query/ai_query.json @@ -18,6 +18,9 @@ }, "5": { "ClipVitB32Image": "UNIT" + }, + "6": { + "ClipVitB32Text": "UNIT" } } }, @@ -284,16 +287,6 @@ } } }, - "ImageAction": { - "ENUM": { - "0": { - "ResizeImage": "UNIT" - }, - "1": { - "ErrorIfDimensionsMismatch": "UNIT" - } - } - }, "MetadataValue": { "ENUM": { "0": { @@ -419,18 +412,10 @@ "PreprocessAction": { "ENUM": { "0": { - "RawString": { - "NEWTYPE": { - "TYPENAME": "StringAction" - } - } + "NoPreprocessing": "UNIT" }, "1": { - "Image": { - "NEWTYPE": { - "TYPENAME": "ImageAction" - } - } + "ModelPreprocessing": "UNIT" } } }, @@ -449,15 +434,5 @@ } } } - }, - "StringAction": { - "ENUM": { - "0": { - "TruncateIfTokensExceed": "UNIT" - }, - "1": { - "ErrorIfTokensExceed": "UNIT" - } - } } } \ No newline at end of file diff --git a/type_specs/response/ai_response.json b/type_specs/response/ai_response.json index 997acf1c..de1e29b2 100644 --- a/type_specs/response/ai_response.json +++ b/type_specs/response/ai_response.json @@ -18,6 +18,9 @@ }, "5": { "ClipVitB32Image": "UNIT" + }, + "6": { + "ClipVitB32Text": "UNIT" } } }, From cf5ae28a28b4c3e6a2c50bfc7fe6c0b825010f4e Mon Sep 17 00:00:00 2001 From: HabeebShopeju Date: Tue, 26 Nov 2024 11:23:26 +0000 Subject: [PATCH 08/15] Removed fastembed --- ahnlich/Cargo.lock | 51 ++----------------------------------------- ahnlich/ai/Cargo.toml | 3 +-- 2 files changed, 3 insertions(+), 51 deletions(-) diff --git a/ahnlich/Cargo.lock b/ahnlich/Cargo.lock index 8a7ca0ed..82472869 100644 --- a/ahnlich/Cargo.lock +++ b/ahnlich/Cargo.lock @@ -102,7 +102,6 @@ dependencies = [ "deadpool", "dirs", "fallible_collections", - "fastembed", "flurry", "futures", "hf-hub", @@ -123,7 +122,7 @@ dependencies = [ "termcolor", "thiserror", "tiktoken-rs", - "tokenizers 0.20.1", + "tokenizers", "tokio", "tokio-util", "tracer", @@ -1153,22 +1152,6 @@ dependencies = [ "regex", ] -[[package]] -name = "fastembed" -version = "4.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "14916e4d447b8a8f1c62e0cf720bada980c27e9371c44b52b79e0336df97185d" -dependencies = [ - "anyhow", - "hf-hub", - "image", - "ndarray", - "ort", - "rayon", - "serde_json", - "tokenizers 0.19.1", -] - [[package]] name = "fastrand" version = "2.1.1" @@ -2324,6 +2307,7 @@ version = "2.0.0-rc.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "45e45a172e6c0fb7d640e92c7740f4ea476bfc49ef5c52ea9c73e9fae32b09fe" dependencies = [ + "half", "ndarray", "ort-sys", "thiserror", @@ -3483,37 +3467,6 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" -[[package]] -name = "tokenizers" -version = "0.19.1" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "e500fad1dd3af3d626327e6a3fe5050e664a6eaa4708b8ca92f1794aaf73e6fd" -dependencies = [ - "aho-corasick", - "derive_builder", - "esaxx-rs", - "getrandom", - "itertools 0.12.1", - "lazy_static", - "log", - "macro_rules_attribute", - "monostate", - "onig", - "paste", - "rand", - "rayon", - "rayon-cond", - "regex", - "regex-syntax 0.8.5", - "serde", - "serde_json", - "spm_precompiled", - "thiserror", - "unicode-normalization-alignments", - "unicode-segmentation", - "unicode_categories", -] - [[package]] name = "tokenizers" version = "0.20.1" diff --git a/ahnlich/ai/Cargo.toml b/ahnlich/ai/Cargo.toml index 1a38bf7c..32381941 100644 --- a/ahnlich/ai/Cargo.toml +++ b/ahnlich/ai/Cargo.toml @@ -38,10 +38,9 @@ strum = { version = "0.26", features = ["derive"] } log.workspace = true fallible_collections.workspace = true rayon.workspace = true -fastembed = { version = "4.0.0", features = ["default"] } hf-hub = { version = "0.3", default-features = false } dirs = "5.0.1" -ort = { version = "=2.0.0-rc.5", default-features = false } +ort = { version = "=2.0.0-rc.5", features = ["ndarray"] } moka = { version = "0.12.8", features = ["future"] } tracing-opentelemetry.workspace = true futures.workspace = true From 58a657c4d63ea3737af9b13e402d8f0b13d78aad Mon Sep 17 00:00:00 2001 From: Diretnan Domnan Date: Tue, 26 Nov 2024 20:38:58 +0100 Subject: [PATCH 09/15] Got rid of some mutexing as we want to lock as little as possible --- ahnlich/ai/src/engine/ai/providers/ort.rs | 5 +- .../processors/imagearray_to_ndarray.rs | 10 +- .../engine/ai/providers/processors/pooling.rs | 54 +++++------ .../ai/providers/processors/postprocessor.rs | 54 ++++++----- .../ai/providers/processors/preprocessor.rs | 92 ++++++++++--------- .../ai/providers/processors/tokenize.rs | 20 ++-- 6 files changed, 121 insertions(+), 114 deletions(-) diff --git a/ahnlich/ai/src/engine/ai/providers/ort.rs b/ahnlich/ai/src/engine/ai/providers/ort.rs index c2acdeb9..e75300b9 100644 --- a/ahnlich/ai/src/engine/ai/providers/ort.rs +++ b/ahnlich/ai/src/engine/ai/providers/ort.rs @@ -289,7 +289,8 @@ impl ORTProvider { ids_array.extend(ids.iter().map(|x| *x as i64)); mask_array.extend(mask.iter().map(|x| *x as i64)); if let Some(ref mut token_type_ids_array) = token_type_ids_array { - token_type_ids_array.extend(encoding.get_type_ids().iter().map(|x| *x as i64)); + token_type_ids_array + .extend(encoding.get_type_ids().iter().map(|x| *x as i64)); } }); @@ -383,7 +384,7 @@ impl ProviderTrait for ORTProvider { self.model = Some(ORTModel::Image(ORTImageModel { repo_name, weights_file, - session: Some(session) + session: Some(session), })); let preprocessor = ORTImagePreprocessor::load(self.supported_models.unwrap(), model_repo)?; diff --git a/ahnlich/ai/src/engine/ai/providers/processors/imagearray_to_ndarray.rs b/ahnlich/ai/src/engine/ai/providers/processors/imagearray_to_ndarray.rs index 14bcc092..65b6b223 100644 --- a/ahnlich/ai/src/engine/ai/providers/processors/imagearray_to_ndarray.rs +++ b/ahnlich/ai/src/engine/ai/providers/processors/imagearray_to_ndarray.rs @@ -1,6 +1,5 @@ use crate::engine::ai::providers::processors::{Preprocessor, PreprocessorData}; use crate::error::AIProxyError; -use std::sync::Mutex; pub struct ImageArrayToNdArray; @@ -8,20 +7,23 @@ impl Preprocessor for ImageArrayToNdArray { fn process(&self, data: PreprocessorData) -> Result { match data { PreprocessorData::ImageArray(mut arrays) => { - let array_shapes = Mutex::new(vec![]); // Not using par_iter_mut here 
because it messes up the order of the images + // TODO: Figure out if it's more expensive to use par_iter_mut with enumerate or + // just keep doing it sequentially let array_views = arrays .iter_mut() .map(|image_arr| { image_arr.onnx_transform(); - array_shapes.lock().unwrap().push(image_arr.image_dim()); image_arr.view() }) .collect::>(); - let array_shapes = array_shapes.into_inner().unwrap(); let pixel_values_array = ndarray::stack(ndarray::Axis(0), &array_views).map_err(|_| { + let array_shapes = arrays + .iter() + .map(|image| image.image_dim()) + .collect::>(); AIProxyError::ImageArrayToNdArrayError { message: format!( "Images must have same dimensions, instead found: {:?}.", diff --git a/ahnlich/ai/src/engine/ai/providers/processors/pooling.rs b/ahnlich/ai/src/engine/ai/providers/processors/pooling.rs index ea4c8728..0d2be85b 100644 --- a/ahnlich/ai/src/engine/ai/providers/processors/pooling.rs +++ b/ahnlich/ai/src/engine/ai/providers/processors/pooling.rs @@ -2,11 +2,13 @@ use crate::engine::ai::providers::processors::{Postprocessor, PostprocessorData} use crate::error::AIProxyError; use ndarray::{s, Array, Axis, Ix2}; +#[derive(Clone)] pub enum Pooling { Regular(RegularPooling), - Mean(MeanPooling), + Mean(MeanPoolingBuilder), } +#[derive(Copy, Clone, Default)] pub struct RegularPooling; impl Postprocessor for RegularPooling { @@ -24,44 +26,38 @@ impl Postprocessor for RegularPooling { } } -#[derive(Default)] -pub struct MeanPooling { - attention_mask: Option>, -} +#[derive(Clone, Default)] +pub struct MeanPoolingBuilder; -impl MeanPooling { - pub fn new() -> Self { - Self { - attention_mask: None, - } +impl MeanPoolingBuilder { + pub fn with_attention_mask<'a>(&'a self, attention_mask: Array) -> MeanPooling { + MeanPooling { attention_mask } } +} - pub fn set_attention_mask(&mut self, attention_mask: Option>) { - self.attention_mask = attention_mask; - } +#[derive(Clone, Default)] +pub struct MeanPooling { + attention_mask: Array, } impl Postprocessor for MeanPooling { fn process(&self, data: PostprocessorData) -> Result { match data { PostprocessorData::NdArray3(array) => { - let attention_mask = match &self.attention_mask { - Some(mask) => { - let attention_mask = mask.mapv(|x| x as f32); - attention_mask - .insert_axis(Axis(2)) - .broadcast(array.dim()) - .ok_or(AIProxyError::PoolingError { - message: format!( - "Could not broadcast attention mask with shape {:?} to \ + let attention_mask = { + let attention_mask = self.attention_mask.mapv(|x| x as f32); + attention_mask + .insert_axis(Axis(2)) + .broadcast(array.dim()) + .ok_or(AIProxyError::PoolingError { + message: format!( + "Could not broadcast attention mask with shape {:?} to \ shape {:?} of the input tensor.", - mask.shape(), - array.shape() - ), - })? - .to_owned() - } - None => Array::ones(array.dim()), + self.attention_mask.shape(), + array.shape() + ), + })? 
+ .to_owned() }; let masked_array = &attention_mask * &array; diff --git a/ahnlich/ai/src/engine/ai/providers/processors/postprocessor.rs b/ahnlich/ai/src/engine/ai/providers/processors/postprocessor.rs index bbe28018..54f580f2 100644 --- a/ahnlich/ai/src/engine/ai/providers/processors/postprocessor.rs +++ b/ahnlich/ai/src/engine/ai/providers/processors/postprocessor.rs @@ -6,7 +6,8 @@ use crate::engine::ai::providers::processors::{Postprocessor, PostprocessorData} use crate::error::AIProxyError; use ndarray::{Array, Ix2}; use ort::SessionOutputs; -use std::sync::{Arc, Mutex}; + +use super::pooling::MeanPoolingBuilder; pub enum ORTPostprocessor { Image(ORTImagePostprocessor), @@ -16,7 +17,7 @@ pub enum ORTPostprocessor { pub struct ORTTextPostprocessor { model: SupportedModels, onnx_output_transform: OnnxOutputTransform, - pooling: Arc>, + pooling: Pooling, normalize: Option, } @@ -35,24 +36,26 @@ impl ORTTextPostprocessor { message: "Unsupported model for ORTTextPostprocessor".to_string(), })?, }; - let ops = match supported_model { + let (pooling, normalize) = match supported_model { SupportedModels::AllMiniLML6V2 | SupportedModels::AllMiniLML12V2 => { - Ok((Pooling::Mean(MeanPooling::new()), Some(VectorNormalize))) + (Pooling::Mean(MeanPoolingBuilder), Some(VectorNormalize)) } SupportedModels::BGEBaseEnV15 | SupportedModels::BGELargeEnV15 => { - Ok((Pooling::Regular(RegularPooling), Some(VectorNormalize))) + (Pooling::Regular(RegularPooling), Some(VectorNormalize)) } - SupportedModels::ClipVitB32Text => Ok((Pooling::Mean(MeanPooling::new()), None)), - _ => Err(AIProxyError::ModelPostprocessingError { - model_name: supported_model.to_string(), - message: "Unsupported model for ORTTextPostprocessor".to_string(), - }), - }?; + SupportedModels::ClipVitB32Text => (Pooling::Mean(MeanPoolingBuilder), None), + _ => { + return Err(AIProxyError::ModelPostprocessingError { + model_name: supported_model.to_string(), + message: "Unsupported model for ORTTextPostprocessor".to_string(), + }) + } + }; Ok(Self { model: supported_model, onnx_output_transform: output_transform, - pooling: Arc::new(Mutex::new(ops.0)), - normalize: ops.1, + pooling, + normalize, }) } @@ -64,19 +67,15 @@ impl ORTTextPostprocessor { let embeddings = self .onnx_output_transform .process(PostprocessorData::OnnxOutput(session_outputs))?; - let mut pooling = - self.pooling - .lock() - .map_err(|_| AIProxyError::ModelPostprocessingError { - model_name: self.model.to_string(), - message: "Failed to acquire lock on pooling.".to_string(), - })?; - let pooled = match &mut *pooling { - Pooling::Regular(pooling) => pooling.process(embeddings)?, - Pooling::Mean(pooling) => { - pooling.set_attention_mask(Some(attention_mask)); - pooling.process(embeddings)? 
+ let pooling_impl = match &self.pooling { + Pooling::Mean(ref pooling) => { + PoolingImpl::Mean(pooling.with_attention_mask(attention_mask)) } + Pooling::Regular(a) => PoolingImpl::Regular(*a), + }; + let pooled = match pooling_impl { + PoolingImpl::Regular(ref pooling) => pooling.process(embeddings)?, + PoolingImpl::Mean(ref pooling) => pooling.process(embeddings)?, }; let result = match &self.normalize { Some(normalize) => normalize.process(pooled), @@ -92,6 +91,11 @@ impl ORTTextPostprocessor { } } +enum PoolingImpl { + Regular(RegularPooling), + Mean(MeanPooling), +} + pub struct ORTImagePostprocessor { model: SupportedModels, onnx_output_transform: OnnxOutputTransform, diff --git a/ahnlich/ai/src/engine/ai/providers/processors/preprocessor.rs b/ahnlich/ai/src/engine/ai/providers/processors/preprocessor.rs index a7bfe5d2..76a54196 100644 --- a/ahnlich/ai/src/engine/ai/providers/processors/preprocessor.rs +++ b/ahnlich/ai/src/engine/ai/providers/processors/preprocessor.rs @@ -56,49 +56,56 @@ impl ORTImagePreprocessor { pub fn process(&self, data: Vec) -> Result, AIProxyError> { let mut data = PreprocessorData::ImageArray(data); data = match self.resize { - Some(ref resize) => resize.process(data).map_err( - |e| AIProxyError::ModelPreprocessingError { - model_name: self.model.to_string(), - message: format!("Failed to process resize: {}", e), - }, - )?, - None => data, - }; - - data = match self.center_crop { - Some(ref center_crop) => center_crop.process(data).map_err( - |e| AIProxyError::ModelPreprocessingError { - model_name: self.model.to_string(), - message: format!("Failed to process center crop: {}", e), - }, - )?, + Some(ref resize) => { + resize + .process(data) + .map_err(|e| AIProxyError::ModelPreprocessingError { + model_name: self.model.to_string(), + message: format!("Failed to process resize: {}", e), + })? + } None => data, }; - data = self.imagearray_to_ndarray.process(data).map_err( - |e| AIProxyError::ModelPreprocessingError { + data = + match self.center_crop { + Some(ref center_crop) => center_crop.process(data).map_err(|e| { + AIProxyError::ModelPreprocessingError { + model_name: self.model.to_string(), + message: format!("Failed to process center crop: {}", e), + } + })?, + None => data, + }; + + data = self.imagearray_to_ndarray.process(data).map_err(|e| { + AIProxyError::ModelPreprocessingError { model_name: self.model.to_string(), message: format!("Failed to process imagearray to ndarray: {}", e), - }, - )?; + } + })?; data = match self.rescale { - Some(ref rescale) => rescale.process(data).map_err( - |e| AIProxyError::ModelPreprocessingError { - model_name: self.model.to_string(), - message: format!("Failed to process rescale: {}", e), - }, - )?, + Some(ref rescale) => { + rescale + .process(data) + .map_err(|e| AIProxyError::ModelPreprocessingError { + model_name: self.model.to_string(), + message: format!("Failed to process rescale: {}", e), + })? + } None => data, }; data = match self.normalize { - Some(ref normalize) => normalize.process(data).map_err( - |e| AIProxyError::ModelPreprocessingError { - model_name: self.model.to_string(), - message: format!("Failed to process normalize: {}", e), - }, - )?, + Some(ref normalize) => { + normalize + .process(data) + .map_err(|e| AIProxyError::ModelPreprocessingError { + model_name: self.model.to_string(), + message: format!("Failed to process normalize: {}", e), + })? 
+ } None => data, }; @@ -144,19 +151,20 @@ impl ORTTextPreprocessor { truncate: bool, ) -> Result, AIProxyError> { let mut data = PreprocessorData::Text(data); - let mut tokenize = self.tokenize.lock().map_err(|_| { - AIProxyError::ModelPreprocessingError { - model_name: self.model.to_string(), - message: "Failed to acquire lock on tokenize.".to_string(), - } - })?; + let mut tokenize = + self.tokenize + .lock() + .map_err(|_| AIProxyError::ModelPreprocessingError { + model_name: self.model.to_string(), + message: "Failed to acquire lock on tokenize.".to_string(), + })?; let _ = tokenize.set_truncate(truncate); - data = tokenize.process(data).map_err( - |e| AIProxyError::ModelPreprocessingError { + data = tokenize + .process(data) + .map_err(|e| AIProxyError::ModelPreprocessingError { model_name: self.model.to_string(), message: format!("Failed to process tokenize: {}", e), - }, - )?; + })?; match data { PreprocessorData::EncodedText(encodings) => Ok(encodings), diff --git a/ahnlich/ai/src/engine/ai/providers/processors/tokenize.rs b/ahnlich/ai/src/engine/ai/providers/processors/tokenize.rs index bf1eaf10..6a598460 100644 --- a/ahnlich/ai/src/engine/ai/providers/processors/tokenize.rs +++ b/ahnlich/ai/src/engine/ai/providers/processors/tokenize.rs @@ -28,11 +28,7 @@ impl Tokenize { &model_repo .get(&tokenizer_files.tokenizer_file) .map_err(|e| AIProxyError::ModelConfigLoadError { - message: format!( - "failed to fetch {}, {}", - &tokenizer_files.tokenizer_file, - e - ), + message: format!("failed to fetch {}, {}", &tokenizer_files.tokenizer_file, e), })?, )?; let mut config_reader = HFConfigReader::new(model_repo); @@ -122,13 +118,13 @@ impl Tokenize { pub fn set_truncate(&mut self, truncate: bool) -> Result<(), AIProxyError> { let tokenizer = if truncate { self.tokenizer - .with_truncation(Some(TruncationParams { - max_length: self.model_max_length, - ..Default::default() - })) - .map_err(|_| AIProxyError::ModelTokenizerLoadError { - message: "Error setting truncation params.".to_string(), - })? + .with_truncation(Some(TruncationParams { + max_length: self.model_max_length, + ..Default::default() + })) + .map_err(|_| AIProxyError::ModelTokenizerLoadError { + message: "Error setting truncation params.".to_string(), + })? 
} else { self.tokenizer.with_truncation(None).map_err(|_| { AIProxyError::ModelTokenizerLoadError { From 9f5273743f4ec1797b45882c9387df8fafe145de Mon Sep 17 00:00:00 2001 From: Diretnan Domnan Date: Tue, 26 Nov 2024 21:34:40 +0100 Subject: [PATCH 10/15] Fix OnnxOutputTransform for Resnet50 --- .../ai/providers/processors/onnx_output_transform.rs | 6 +++--- .../ai/src/engine/ai/providers/processors/pooling.rs | 2 +- .../engine/ai/providers/processors/postprocessor.rs | 11 ++++------- 3 files changed, 8 insertions(+), 11 deletions(-) diff --git a/ahnlich/ai/src/engine/ai/providers/processors/onnx_output_transform.rs b/ahnlich/ai/src/engine/ai/providers/processors/onnx_output_transform.rs index 954d1399..ff84d7b5 100644 --- a/ahnlich/ai/src/engine/ai/providers/processors/onnx_output_transform.rs +++ b/ahnlich/ai/src/engine/ai/providers/processors/onnx_output_transform.rs @@ -3,11 +3,11 @@ use crate::error::AIProxyError; use ndarray::{Ix2, Ix3}; pub struct OnnxOutputTransform { - output_key: String, + output_key: &'static str, } impl OnnxOutputTransform { - pub fn new(output_key: String) -> Self { + pub fn new(output_key: &'static str) -> Self { Self { output_key } } } @@ -16,7 +16,7 @@ impl Postprocessor for OnnxOutputTransform { fn process(&self, data: PostprocessorData) -> Result { match data { PostprocessorData::OnnxOutput(onnx_output) => { - let output = onnx_output.get(self.output_key.as_str()).ok_or_else(|| { + let output = onnx_output.get(self.output_key).ok_or_else(|| { AIProxyError::OnnxOutputTransformError { message: format!( "Output key '{}' not found in the OnnxOutput.", diff --git a/ahnlich/ai/src/engine/ai/providers/processors/pooling.rs b/ahnlich/ai/src/engine/ai/providers/processors/pooling.rs index 0d2be85b..40b6510a 100644 --- a/ahnlich/ai/src/engine/ai/providers/processors/pooling.rs +++ b/ahnlich/ai/src/engine/ai/providers/processors/pooling.rs @@ -30,7 +30,7 @@ impl Postprocessor for RegularPooling { pub struct MeanPoolingBuilder; impl MeanPoolingBuilder { - pub fn with_attention_mask<'a>(&'a self, attention_mask: Array) -> MeanPooling { + pub fn with_attention_mask(&self, attention_mask: Array) -> MeanPooling { MeanPooling { attention_mask } } } diff --git a/ahnlich/ai/src/engine/ai/providers/processors/postprocessor.rs b/ahnlich/ai/src/engine/ai/providers/processors/postprocessor.rs index 54f580f2..7388f567 100644 --- a/ahnlich/ai/src/engine/ai/providers/processors/postprocessor.rs +++ b/ahnlich/ai/src/engine/ai/providers/processors/postprocessor.rs @@ -27,10 +27,8 @@ impl ORTTextPostprocessor { SupportedModels::AllMiniLML6V2 | SupportedModels::AllMiniLML12V2 | SupportedModels::BGEBaseEnV15 - | SupportedModels::BGELargeEnV15 => { - OnnxOutputTransform::new("last_hidden_state".to_string()) - } - SupportedModels::ClipVitB32Text => OnnxOutputTransform::new("text_embeds".to_string()), + | SupportedModels::BGELargeEnV15 => OnnxOutputTransform::new("last_hidden_state"), + SupportedModels::ClipVitB32Text => OnnxOutputTransform::new("text_embeds"), _ => Err(AIProxyError::ModelPostprocessingError { model_name: supported_model.to_string(), message: "Unsupported model for ORTTextPostprocessor".to_string(), @@ -105,9 +103,8 @@ pub struct ORTImagePostprocessor { impl ORTImagePostprocessor { pub fn load(supported_model: SupportedModels) -> Result { let output_transform = match supported_model { - SupportedModels::Resnet50 | SupportedModels::ClipVitB32Image => { - OnnxOutputTransform::new("image_embeds".to_string()) - } + SupportedModels::Resnet50 => 
OnnxOutputTransform::new("output"), + SupportedModels::ClipVitB32Image => OnnxOutputTransform::new("image_embeds"), _ => Err(AIProxyError::ModelPostprocessingError { model_name: supported_model.to_string(), message: "Unsupported model for ORTImagePostprocessor".to_string(), From 033487e94d22737199f3a695a7b7707c5b7dae87 Mon Sep 17 00:00:00 2001 From: Diretnan Domnan Date: Tue, 26 Nov 2024 21:36:14 +0100 Subject: [PATCH 11/15] Formatting Python files --- .../ahnlich_client_py/internals/ai_query.py | 65 ++++++++------ .../internals/ai_response.py | 86 +++++++++++-------- .../internals/bincode/__init__.py | 4 +- .../ahnlich_client_py/internals/db_query.py | 45 ++++++---- .../internals/db_response.py | 67 ++++++++------- .../internals/serde_binary/__init__.py | 2 +- .../internals/serde_types/__init__.py | 5 +- .../db_client/test_client_unit_commands.py | 2 - 8 files changed, 159 insertions(+), 117 deletions(-) diff --git a/sdk/ahnlich-client-py/ahnlich_client_py/internals/ai_query.py b/sdk/ahnlich-client-py/ahnlich_client_py/internals/ai_query.py index db9475c1..368bddbb 100644 --- a/sdk/ahnlich-client-py/ahnlich_client_py/internals/ai_query.py +++ b/sdk/ahnlich-client-py/ahnlich_client_py/internals/ai_query.py @@ -1,8 +1,10 @@ # pyre-strict -from dataclasses import dataclass import typing -from ahnlich_client_py.internals import serde_types as st +from dataclasses import dataclass + from ahnlich_client_py.internals import bincode +from ahnlich_client_py.internals import serde_types as st + class AIModel: VARIANTS = [] # type: typing.Sequence[typing.Type[AIModel]] @@ -11,10 +13,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, AIModel) @staticmethod - def bincode_deserialize(input: bytes) -> 'AIModel': + def bincode_deserialize(input: bytes) -> "AIModel": v, buffer = bincode.deserialize(input, AIModel) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v @@ -59,6 +61,7 @@ class AIModel__ClipVitB32Text(AIModel): INDEX = 6 # type: int pass + AIModel.VARIANTS = [ AIModel__AllMiniLML6V2, AIModel__AllMiniLML12V2, @@ -77,10 +80,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, AIQuery) @staticmethod - def bincode_deserialize(input: bytes) -> 'AIQuery': + def bincode_deserialize(input: bytes) -> "AIQuery": v, buffer = bincode.deserialize(input, AIQuery) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v @@ -147,7 +150,9 @@ class AIQuery__DropNonLinearAlgorithmIndex(AIQuery): class AIQuery__Set(AIQuery): INDEX = 7 # type: int store: str - inputs: typing.Sequence[typing.Tuple["StoreInput", typing.Dict[str, "MetadataValue"]]] + inputs: typing.Sequence[ + typing.Tuple["StoreInput", typing.Dict[str, "MetadataValue"]] + ] preprocess_action: "PreprocessAction" @@ -188,6 +193,7 @@ class AIQuery__Ping(AIQuery): INDEX = 13 # type: int pass + AIQuery.VARIANTS = [ AIQuery__CreateStore, AIQuery__GetPred, @@ -215,10 +221,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, AIServerQuery) @staticmethod - def bincode_deserialize(input: bytes) -> 'AIServerQuery': + def bincode_deserialize(input: bytes) -> "AIServerQuery": v, buffer = bincode.deserialize(input, AIServerQuery) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v @@ -229,10 +235,10 @@ 
def bincode_serialize(self) -> bytes: return bincode.serialize(self, AIStoreInputType) @staticmethod - def bincode_deserialize(input: bytes) -> 'AIStoreInputType': + def bincode_deserialize(input: bytes) -> "AIStoreInputType": v, buffer = bincode.deserialize(input, AIStoreInputType) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v @@ -247,6 +253,7 @@ class AIStoreInputType__Image(AIStoreInputType): INDEX = 1 # type: int pass + AIStoreInputType.VARIANTS = [ AIStoreInputType__RawString, AIStoreInputType__Image, @@ -260,10 +267,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, Algorithm) @staticmethod - def bincode_deserialize(input: bytes) -> 'Algorithm': + def bincode_deserialize(input: bytes) -> "Algorithm": v, buffer = bincode.deserialize(input, Algorithm) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v @@ -290,6 +297,7 @@ class Algorithm__KDTree(Algorithm): INDEX = 3 # type: int pass + Algorithm.VARIANTS = [ Algorithm__EuclideanDistance, Algorithm__DotProductSimilarity, @@ -305,10 +313,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, MetadataValue) @staticmethod - def bincode_deserialize(input: bytes) -> 'MetadataValue': + def bincode_deserialize(input: bytes) -> "MetadataValue": v, buffer = bincode.deserialize(input, MetadataValue) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v @@ -323,6 +331,7 @@ class MetadataValue__Image(MetadataValue): INDEX = 1 # type: int value: typing.Sequence[st.uint8] + MetadataValue.VARIANTS = [ MetadataValue__RawString, MetadataValue__Image, @@ -336,10 +345,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, NonLinearAlgorithm) @staticmethod - def bincode_deserialize(input: bytes) -> 'NonLinearAlgorithm': + def bincode_deserialize(input: bytes) -> "NonLinearAlgorithm": v, buffer = bincode.deserialize(input, NonLinearAlgorithm) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v @@ -348,6 +357,7 @@ class NonLinearAlgorithm__KDTree(NonLinearAlgorithm): INDEX = 0 # type: int pass + NonLinearAlgorithm.VARIANTS = [ NonLinearAlgorithm__KDTree, ] @@ -360,10 +370,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, Predicate) @staticmethod - def bincode_deserialize(input: bytes) -> 'Predicate': + def bincode_deserialize(input: bytes) -> "Predicate": v, buffer = bincode.deserialize(input, Predicate) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v @@ -394,6 +404,7 @@ class Predicate__NotIn(Predicate): key: str value: typing.Sequence["MetadataValue"] + Predicate.VARIANTS = [ Predicate__Equals, Predicate__NotEquals, @@ -409,10 +420,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, PredicateCondition) @staticmethod - def bincode_deserialize(input: bytes) -> 'PredicateCondition': + def bincode_deserialize(input: bytes) -> "PredicateCondition": v, buffer = bincode.deserialize(input, PredicateCondition) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") 
return v @@ -433,6 +444,7 @@ class PredicateCondition__Or(PredicateCondition): INDEX = 2 # type: int value: typing.Tuple["PredicateCondition", "PredicateCondition"] + PredicateCondition.VARIANTS = [ PredicateCondition__Value, PredicateCondition__And, @@ -447,10 +459,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, PreprocessAction) @staticmethod - def bincode_deserialize(input: bytes) -> 'PreprocessAction': + def bincode_deserialize(input: bytes) -> "PreprocessAction": v, buffer = bincode.deserialize(input, PreprocessAction) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v @@ -465,6 +477,7 @@ class PreprocessAction__ModelPreprocessing(PreprocessAction): INDEX = 1 # type: int pass + PreprocessAction.VARIANTS = [ PreprocessAction__NoPreprocessing, PreprocessAction__ModelPreprocessing, @@ -478,10 +491,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, StoreInput) @staticmethod - def bincode_deserialize(input: bytes) -> 'StoreInput': + def bincode_deserialize(input: bytes) -> "StoreInput": v, buffer = bincode.deserialize(input, StoreInput) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v @@ -496,8 +509,8 @@ class StoreInput__Image(StoreInput): INDEX = 1 # type: int value: typing.Sequence[st.uint8] + StoreInput.VARIANTS = [ StoreInput__RawString, StoreInput__Image, ] - diff --git a/sdk/ahnlich-client-py/ahnlich_client_py/internals/ai_response.py b/sdk/ahnlich-client-py/ahnlich_client_py/internals/ai_response.py index 6b5fc608..e71b20ca 100644 --- a/sdk/ahnlich-client-py/ahnlich_client_py/internals/ai_response.py +++ b/sdk/ahnlich-client-py/ahnlich_client_py/internals/ai_response.py @@ -1,8 +1,10 @@ # pyre-strict -from dataclasses import dataclass import typing -from ahnlich_client_py.internals import serde_types as st +from dataclasses import dataclass + from ahnlich_client_py.internals import bincode +from ahnlich_client_py.internals import serde_types as st + class AIModel: VARIANTS = [] # type: typing.Sequence[typing.Type[AIModel]] @@ -11,10 +13,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, AIModel) @staticmethod - def bincode_deserialize(input: bytes) -> 'AIModel': + def bincode_deserialize(input: bytes) -> "AIModel": v, buffer = bincode.deserialize(input, AIModel) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v @@ -59,6 +61,7 @@ class AIModel__ClipVitB32Text(AIModel): INDEX = 6 # type: int pass + AIModel.VARIANTS = [ AIModel__AllMiniLML6V2, AIModel__AllMiniLML12V2, @@ -77,10 +80,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, AIServerResponse) @staticmethod - def bincode_deserialize(input: bytes) -> 'AIServerResponse': + def bincode_deserialize(input: bytes) -> "AIServerResponse": v, buffer = bincode.deserialize(input, AIServerResponse) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v @@ -123,13 +126,21 @@ class AIServerResponse__Set(AIServerResponse): @dataclass(frozen=True) class AIServerResponse__Get(AIServerResponse): INDEX = 6 # type: int - value: typing.Sequence[typing.Tuple[typing.Optional["StoreInput"], typing.Dict[str, "MetadataValue"]]] + value: typing.Sequence[ + 
typing.Tuple[typing.Optional["StoreInput"], typing.Dict[str, "MetadataValue"]] + ] @dataclass(frozen=True) class AIServerResponse__GetSimN(AIServerResponse): INDEX = 7 # type: int - value: typing.Sequence[typing.Tuple[typing.Optional["StoreInput"], typing.Dict[str, "MetadataValue"], "Similarity"]] + value: typing.Sequence[ + typing.Tuple[ + typing.Optional["StoreInput"], + typing.Dict[str, "MetadataValue"], + "Similarity", + ] + ] @dataclass(frozen=True) @@ -143,6 +154,7 @@ class AIServerResponse__CreateIndex(AIServerResponse): INDEX = 9 # type: int value: st.uint64 + AIServerResponse.VARIANTS = [ AIServerResponse__Unit, AIServerResponse__Pong, @@ -165,10 +177,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, AIServerResult) @staticmethod - def bincode_deserialize(input: bytes) -> 'AIServerResult': + def bincode_deserialize(input: bytes) -> "AIServerResult": v, buffer = bincode.deserialize(input, AIServerResult) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v @@ -183,10 +195,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, AIStoreInfo) @staticmethod - def bincode_deserialize(input: bytes) -> 'AIStoreInfo': + def bincode_deserialize(input: bytes) -> "AIStoreInfo": v, buffer = bincode.deserialize(input, AIStoreInfo) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v @@ -197,10 +209,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, AIStoreInputType) @staticmethod - def bincode_deserialize(input: bytes) -> 'AIStoreInputType': + def bincode_deserialize(input: bytes) -> "AIStoreInputType": v, buffer = bincode.deserialize(input, AIStoreInputType) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v @@ -215,6 +227,7 @@ class AIStoreInputType__Image(AIStoreInputType): INDEX = 1 # type: int pass + AIStoreInputType.VARIANTS = [ AIStoreInputType__RawString, AIStoreInputType__Image, @@ -230,10 +243,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, ConnectedClient) @staticmethod - def bincode_deserialize(input: bytes) -> 'ConnectedClient': + def bincode_deserialize(input: bytes) -> "ConnectedClient": v, buffer = bincode.deserialize(input, ConnectedClient) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v @@ -244,10 +257,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, MetadataValue) @staticmethod - def bincode_deserialize(input: bytes) -> 'MetadataValue': + def bincode_deserialize(input: bytes) -> "MetadataValue": v, buffer = bincode.deserialize(input, MetadataValue) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v @@ -262,6 +275,7 @@ class MetadataValue__Image(MetadataValue): INDEX = 1 # type: int value: typing.Sequence[st.uint8] + MetadataValue.VARIANTS = [ MetadataValue__RawString, MetadataValue__Image, @@ -275,10 +289,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, Result) @staticmethod - def bincode_deserialize(input: bytes) -> 'Result': + def bincode_deserialize(input: bytes) -> "Result": v, buffer = bincode.deserialize(input, Result) if 
buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v @@ -293,6 +307,7 @@ class Result__Err(Result): INDEX = 1 # type: int value: str + Result.VARIANTS = [ Result__Ok, Result__Err, @@ -311,10 +326,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, ServerInfo) @staticmethod - def bincode_deserialize(input: bytes) -> 'ServerInfo': + def bincode_deserialize(input: bytes) -> "ServerInfo": v, buffer = bincode.deserialize(input, ServerInfo) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v @@ -325,10 +340,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, ServerType) @staticmethod - def bincode_deserialize(input: bytes) -> 'ServerType': + def bincode_deserialize(input: bytes) -> "ServerType": v, buffer = bincode.deserialize(input, ServerType) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v @@ -343,6 +358,7 @@ class ServerType__AI(ServerType): INDEX = 1 # type: int pass + ServerType.VARIANTS = [ ServerType__Database, ServerType__AI, @@ -357,10 +373,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, Similarity) @staticmethod - def bincode_deserialize(input: bytes) -> 'Similarity': + def bincode_deserialize(input: bytes) -> "Similarity": v, buffer = bincode.deserialize(input, Similarity) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v @@ -371,10 +387,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, StoreInput) @staticmethod - def bincode_deserialize(input: bytes) -> 'StoreInput': + def bincode_deserialize(input: bytes) -> "StoreInput": v, buffer = bincode.deserialize(input, StoreInput) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v @@ -389,6 +405,7 @@ class StoreInput__Image(StoreInput): INDEX = 1 # type: int value: typing.Sequence[st.uint8] + StoreInput.VARIANTS = [ StoreInput__RawString, StoreInput__Image, @@ -404,10 +421,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, StoreUpsert) @staticmethod - def bincode_deserialize(input: bytes) -> 'StoreUpsert': + def bincode_deserialize(input: bytes) -> "StoreUpsert": v, buffer = bincode.deserialize(input, StoreUpsert) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v @@ -420,10 +437,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, SystemTime) @staticmethod - def bincode_deserialize(input: bytes) -> 'SystemTime': + def bincode_deserialize(input: bytes) -> "SystemTime": v, buffer = bincode.deserialize(input, SystemTime) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v @@ -437,9 +454,8 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, Version) @staticmethod - def bincode_deserialize(input: bytes) -> 'Version': + def bincode_deserialize(input: bytes) -> "Version": v, buffer = bincode.deserialize(input, Version) if buffer: - raise st.DeserializationError("Some input bytes were not 
read"); + raise st.DeserializationError("Some input bytes were not read") return v - diff --git a/sdk/ahnlich-client-py/ahnlich_client_py/internals/bincode/__init__.py b/sdk/ahnlich-client-py/ahnlich_client_py/internals/bincode/__init__.py index 4e5e0837..38cbd7ff 100644 --- a/sdk/ahnlich-client-py/ahnlich_client_py/internals/bincode/__init__.py +++ b/sdk/ahnlich-client-py/ahnlich_client_py/internals/bincode/__init__.py @@ -1,16 +1,16 @@ # Copyright (c) Facebook, Inc. and its affiliates # SPDX-License-Identifier: MIT OR Apache-2.0 -import dataclasses import collections +import dataclasses import io import struct import typing from copy import copy from typing import get_type_hints -from ahnlich_client_py.internals import serde_types as st from ahnlich_client_py.internals import serde_binary as sb +from ahnlich_client_py.internals import serde_types as st # Maximum length in practice for sequences (e.g. in Java). MAX_LENGTH = (1 << 31) - 1 diff --git a/sdk/ahnlich-client-py/ahnlich_client_py/internals/db_query.py b/sdk/ahnlich-client-py/ahnlich_client_py/internals/db_query.py index b6120b2b..b281f346 100644 --- a/sdk/ahnlich-client-py/ahnlich_client_py/internals/db_query.py +++ b/sdk/ahnlich-client-py/ahnlich_client_py/internals/db_query.py @@ -1,8 +1,10 @@ # pyre-strict -from dataclasses import dataclass import typing -from ahnlich_client_py.internals import serde_types as st +from dataclasses import dataclass + from ahnlich_client_py.internals import bincode +from ahnlich_client_py.internals import serde_types as st + class Algorithm: VARIANTS = [] # type: typing.Sequence[typing.Type[Algorithm]] @@ -11,10 +13,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, Algorithm) @staticmethod - def bincode_deserialize(input: bytes) -> 'Algorithm': + def bincode_deserialize(input: bytes) -> "Algorithm": v, buffer = bincode.deserialize(input, Algorithm) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v @@ -41,6 +43,7 @@ class Algorithm__KDTree(Algorithm): INDEX = 3 # type: int pass + Algorithm.VARIANTS = [ Algorithm__EuclideanDistance, Algorithm__DotProductSimilarity, @@ -59,10 +62,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, Array) @staticmethod - def bincode_deserialize(input: bytes) -> 'Array': + def bincode_deserialize(input: bytes) -> "Array": v, buffer = bincode.deserialize(input, Array) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v @@ -73,10 +76,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, MetadataValue) @staticmethod - def bincode_deserialize(input: bytes) -> 'MetadataValue': + def bincode_deserialize(input: bytes) -> "MetadataValue": v, buffer = bincode.deserialize(input, MetadataValue) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v @@ -91,6 +94,7 @@ class MetadataValue__Image(MetadataValue): INDEX = 1 # type: int value: typing.Sequence[st.uint8] + MetadataValue.VARIANTS = [ MetadataValue__RawString, MetadataValue__Image, @@ -104,10 +108,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, NonLinearAlgorithm) @staticmethod - def bincode_deserialize(input: bytes) -> 'NonLinearAlgorithm': + def bincode_deserialize(input: bytes) -> "NonLinearAlgorithm": v, buffer = 
bincode.deserialize(input, NonLinearAlgorithm) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v @@ -116,6 +120,7 @@ class NonLinearAlgorithm__KDTree(NonLinearAlgorithm): INDEX = 0 # type: int pass + NonLinearAlgorithm.VARIANTS = [ NonLinearAlgorithm__KDTree, ] @@ -128,10 +133,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, Predicate) @staticmethod - def bincode_deserialize(input: bytes) -> 'Predicate': + def bincode_deserialize(input: bytes) -> "Predicate": v, buffer = bincode.deserialize(input, Predicate) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v @@ -162,6 +167,7 @@ class Predicate__NotIn(Predicate): key: str value: typing.Sequence["MetadataValue"] + Predicate.VARIANTS = [ Predicate__Equals, Predicate__NotEquals, @@ -177,10 +183,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, PredicateCondition) @staticmethod - def bincode_deserialize(input: bytes) -> 'PredicateCondition': + def bincode_deserialize(input: bytes) -> "PredicateCondition": v, buffer = bincode.deserialize(input, PredicateCondition) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v @@ -201,6 +207,7 @@ class PredicateCondition__Or(PredicateCondition): INDEX = 2 # type: int value: typing.Tuple["PredicateCondition", "PredicateCondition"] + PredicateCondition.VARIANTS = [ PredicateCondition__Value, PredicateCondition__And, @@ -215,10 +222,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, Query) @staticmethod - def bincode_deserialize(input: bytes) -> 'Query': + def bincode_deserialize(input: bytes) -> "Query": v, buffer = bincode.deserialize(input, Query) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v @@ -337,6 +344,7 @@ class Query__Ping(Query): INDEX = 15 # type: int pass + Query.VARIANTS = [ Query__CreateStore, Query__GetKey, @@ -366,9 +374,8 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, ServerQuery) @staticmethod - def bincode_deserialize(input: bytes) -> 'ServerQuery': + def bincode_deserialize(input: bytes) -> "ServerQuery": v, buffer = bincode.deserialize(input, ServerQuery) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v - diff --git a/sdk/ahnlich-client-py/ahnlich_client_py/internals/db_response.py b/sdk/ahnlich-client-py/ahnlich_client_py/internals/db_response.py index acd3baa1..d1d0a6c4 100644 --- a/sdk/ahnlich-client-py/ahnlich_client_py/internals/db_response.py +++ b/sdk/ahnlich-client-py/ahnlich_client_py/internals/db_response.py @@ -1,8 +1,10 @@ # pyre-strict -from dataclasses import dataclass import typing -from ahnlich_client_py.internals import serde_types as st +from dataclasses import dataclass + from ahnlich_client_py.internals import bincode +from ahnlich_client_py.internals import serde_types as st + @dataclass(frozen=True) class Array: @@ -14,10 +16,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, Array) @staticmethod - def bincode_deserialize(input: bytes) -> 'Array': + def bincode_deserialize(input: bytes) -> "Array": v, buffer = bincode.deserialize(input, 
Array) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v @@ -30,10 +32,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, ConnectedClient) @staticmethod - def bincode_deserialize(input: bytes) -> 'ConnectedClient': + def bincode_deserialize(input: bytes) -> "ConnectedClient": v, buffer = bincode.deserialize(input, ConnectedClient) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v @@ -44,10 +46,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, MetadataValue) @staticmethod - def bincode_deserialize(input: bytes) -> 'MetadataValue': + def bincode_deserialize(input: bytes) -> "MetadataValue": v, buffer = bincode.deserialize(input, MetadataValue) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v @@ -62,6 +64,7 @@ class MetadataValue__Image(MetadataValue): INDEX = 1 # type: int value: typing.Sequence[st.uint8] + MetadataValue.VARIANTS = [ MetadataValue__RawString, MetadataValue__Image, @@ -75,10 +78,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, Result) @staticmethod - def bincode_deserialize(input: bytes) -> 'Result': + def bincode_deserialize(input: bytes) -> "Result": v, buffer = bincode.deserialize(input, Result) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v @@ -93,6 +96,7 @@ class Result__Err(Result): INDEX = 1 # type: int value: str + Result.VARIANTS = [ Result__Ok, Result__Err, @@ -111,10 +115,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, ServerInfo) @staticmethod - def bincode_deserialize(input: bytes) -> 'ServerInfo': + def bincode_deserialize(input: bytes) -> "ServerInfo": v, buffer = bincode.deserialize(input, ServerInfo) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v @@ -125,10 +129,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, ServerResponse) @staticmethod - def bincode_deserialize(input: bytes) -> 'ServerResponse': + def bincode_deserialize(input: bytes) -> "ServerResponse": v, buffer = bincode.deserialize(input, ServerResponse) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v @@ -177,7 +181,9 @@ class ServerResponse__Get(ServerResponse): @dataclass(frozen=True) class ServerResponse__GetSimN(ServerResponse): INDEX = 7 # type: int - value: typing.Sequence[typing.Tuple["Array", typing.Dict[str, "MetadataValue"], "Similarity"]] + value: typing.Sequence[ + typing.Tuple["Array", typing.Dict[str, "MetadataValue"], "Similarity"] + ] @dataclass(frozen=True) @@ -191,6 +197,7 @@ class ServerResponse__CreateIndex(ServerResponse): INDEX = 9 # type: int value: st.uint64 + ServerResponse.VARIANTS = [ ServerResponse__Unit, ServerResponse__Pong, @@ -213,10 +220,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, ServerResult) @staticmethod - def bincode_deserialize(input: bytes) -> 'ServerResult': + def bincode_deserialize(input: bytes) -> "ServerResult": v, buffer = bincode.deserialize(input, ServerResult) if buffer: - raise 
st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v @@ -227,10 +234,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, ServerType) @staticmethod - def bincode_deserialize(input: bytes) -> 'ServerType': + def bincode_deserialize(input: bytes) -> "ServerType": v, buffer = bincode.deserialize(input, ServerType) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v @@ -245,6 +252,7 @@ class ServerType__AI(ServerType): INDEX = 1 # type: int pass + ServerType.VARIANTS = [ ServerType__Database, ServerType__AI, @@ -259,10 +267,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, Similarity) @staticmethod - def bincode_deserialize(input: bytes) -> 'Similarity': + def bincode_deserialize(input: bytes) -> "Similarity": v, buffer = bincode.deserialize(input, Similarity) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v @@ -276,10 +284,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, StoreInfo) @staticmethod - def bincode_deserialize(input: bytes) -> 'StoreInfo': + def bincode_deserialize(input: bytes) -> "StoreInfo": v, buffer = bincode.deserialize(input, StoreInfo) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v @@ -292,10 +300,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, StoreUpsert) @staticmethod - def bincode_deserialize(input: bytes) -> 'StoreUpsert': + def bincode_deserialize(input: bytes) -> "StoreUpsert": v, buffer = bincode.deserialize(input, StoreUpsert) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v @@ -308,10 +316,10 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, SystemTime) @staticmethod - def bincode_deserialize(input: bytes) -> 'SystemTime': + def bincode_deserialize(input: bytes) -> "SystemTime": v, buffer = bincode.deserialize(input, SystemTime) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v @@ -325,9 +333,8 @@ def bincode_serialize(self) -> bytes: return bincode.serialize(self, Version) @staticmethod - def bincode_deserialize(input: bytes) -> 'Version': + def bincode_deserialize(input: bytes) -> "Version": v, buffer = bincode.deserialize(input, Version) if buffer: - raise st.DeserializationError("Some input bytes were not read"); + raise st.DeserializationError("Some input bytes were not read") return v - diff --git a/sdk/ahnlich-client-py/ahnlich_client_py/internals/serde_binary/__init__.py b/sdk/ahnlich-client-py/ahnlich_client_py/internals/serde_binary/__init__.py index 0730bd23..a71b03f5 100644 --- a/sdk/ahnlich-client-py/ahnlich_client_py/internals/serde_binary/__init__.py +++ b/sdk/ahnlich-client-py/ahnlich_client_py/internals/serde_binary/__init__.py @@ -7,8 +7,8 @@ Note: This internal module is currently only meant to share code between the BCS and bincode formats. Internal APIs could change in the future. 
""" -import dataclasses import collections +import dataclasses import io import typing from typing import get_type_hints diff --git a/sdk/ahnlich-client-py/ahnlich_client_py/internals/serde_types/__init__.py b/sdk/ahnlich-client-py/ahnlich_client_py/internals/serde_types/__init__.py index 6d72f027..1c85909c 100644 --- a/sdk/ahnlich-client-py/ahnlich_client_py/internals/serde_types/__init__.py +++ b/sdk/ahnlich-client-py/ahnlich_client_py/internals/serde_types/__init__.py @@ -1,9 +1,10 @@ # Copyright (c) Facebook, Inc. and its affiliates # SPDX-License-Identifier: MIT OR Apache-2.0 -import numpy as np -from dataclasses import dataclass import typing +from dataclasses import dataclass + +import numpy as np class SerializationError(ValueError): diff --git a/sdk/ahnlich-client-py/ahnlich_client_py/tests/db_client/test_client_unit_commands.py b/sdk/ahnlich-client-py/ahnlich_client_py/tests/db_client/test_client_unit_commands.py index b31879dc..ed7ecdc9 100644 --- a/sdk/ahnlich-client-py/ahnlich_client_py/tests/db_client/test_client_unit_commands.py +++ b/sdk/ahnlich-client-py/ahnlich_client_py/tests/db_client/test_client_unit_commands.py @@ -6,7 +6,6 @@ def test_client_sends_ping_to_db_success(module_scopped_ahnlich_db): port = module_scopped_ahnlich_db db_client = clients.AhnlichDBClient(address="127.0.0.1", port=port) try: - response: db_response.ServerResult = db_client.ping() assert len(response.results) == 1 assert response.results[0] == db_response.Result__Ok( @@ -18,7 +17,6 @@ def test_client_sends_ping_to_db_success(module_scopped_ahnlich_db): db_client.cleanup() raise e finally: - db_client.cleanup() From bf516b094b4ef4311dc837519a6a14e4f39a4172 Mon Sep 17 00:00:00 2001 From: Diretnan Domnan Date: Tue, 26 Nov 2024 21:45:57 +0100 Subject: [PATCH 12/15] Fix test_ai_store_binary_actions --- ahnlich/ai/src/tests/aiproxy_test.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ahnlich/ai/src/tests/aiproxy_test.rs b/ahnlich/ai/src/tests/aiproxy_test.rs index 3fb2a533..2a906202 100644 --- a/ahnlich/ai/src/tests/aiproxy_test.rs +++ b/ahnlich/ai/src/tests/aiproxy_test.rs @@ -833,7 +833,7 @@ async fn test_ai_proxy_binary_store_actions() { updated: 0, }))); expected.push(Err( - "Image Dimensions [(821, 547)] does not match the expected model dimensions [(224, 224)]" + "Image Dimensions [(547, 821)] does not match the expected model dimensions [(224, 224)]" .to_string(), )); expected.push(Ok(AIServerResponse::Del(1))); From aab36d6e7756064c239b0e0a83573f12832ef921 Mon Sep 17 00:00:00 2001 From: Diretnan Domnan Date: Tue, 26 Nov 2024 21:55:37 +0100 Subject: [PATCH 13/15] Fix test_set_in_store_parse --- ahnlich/dsl/src/tests/ai.rs | 2 +- sdk/ahnlich-client-py/demo_tracing.py | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/ahnlich/dsl/src/tests/ai.rs b/ahnlich/dsl/src/tests/ai.rs index 2f59f4f5..78d13caa 100644 --- a/ahnlich/dsl/src/tests/ai.rs +++ b/ahnlich/dsl/src/tests/ai.rs @@ -343,7 +343,7 @@ fn test_set_in_store_parse() { panic!("Unexpected error pattern found") }; assert_eq!((start, end), (0, 63)); - let input = r#"SET (([This is the life of Haks paragraphed], {name: Haks, category: dev}), ([This is the life of Deven paragraphed], {name: Deven, category: dev})) in geo preprocessaction erroriftokensexceeded"#; + let input = r#"SET (([This is the life of Haks paragraphed], {name: Haks, category: dev}), ([This is the life of Deven paragraphed], {name: Deven, category: dev})) in geo preprocessaction nopreprocessing"#; assert_eq!( 
parse_ai_query(input).expect("Could not parse query input"), vec![AIQuery::Set { diff --git a/sdk/ahnlich-client-py/demo_tracing.py b/sdk/ahnlich-client-py/demo_tracing.py index d96f981b..12992613 100644 --- a/sdk/ahnlich-client-py/demo_tracing.py +++ b/sdk/ahnlich-client-py/demo_tracing.py @@ -63,9 +63,7 @@ def tracer(span_id): builder.set( store_name=ai_store_payload_with_predicates["store_name"], inputs=store_inputs, - preprocess_action=ai_query.PreprocessAction__RawString( - ai_query.StringAction__ErrorIfTokensExceed() - ), + preprocess_action=ai_query.PreprocessAction__NoPreprocessing(), ) builder.create_store(**ai_store_payload_no_predicates) builder.list_stores() From c90b72e3b3db2278e8ca90b94243c6d00c850c26 Mon Sep 17 00:00:00 2001 From: Diretnan Domnan Date: Tue, 26 Nov 2024 22:01:27 +0100 Subject: [PATCH 14/15] Fixing python tests with previous preprocess modes --- .../tests/ai_client/test_ai_client_store_commands.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/sdk/ahnlich-client-py/ahnlich_client_py/tests/ai_client/test_ai_client_store_commands.py b/sdk/ahnlich-client-py/ahnlich_client_py/tests/ai_client/test_ai_client_store_commands.py index 21937ab7..4342f51b 100644 --- a/sdk/ahnlich-client-py/ahnlich_client_py/tests/ai_client/test_ai_client_store_commands.py +++ b/sdk/ahnlich-client-py/ahnlich_client_py/tests/ai_client/test_ai_client_store_commands.py @@ -53,9 +53,7 @@ def test_ai_client_get_pred(spin_up_ahnlich_ai): builder.set( store_name=ai_store_payload_with_predicates["store_name"], inputs=store_inputs, - preprocess_action=ai_query.PreprocessAction__RawString( - ai_query.StringAction__ErrorIfTokensExceed() - ), + preprocess_action=ai_query.PreprocessAction__NoPreprocessing(), ) expected = ai_response.AIServerResult( results=[ @@ -196,9 +194,7 @@ def test_ai_client_del_key(spin_up_ahnlich_ai): builder.set( store_name=ai_store_payload_with_predicates["store_name"], inputs=store_inputs, - preprocess_action=ai_query.PreprocessAction__RawString( - ai_query.StringAction__ErrorIfTokensExceed() - ), + preprocess_action=ai_query.PreprocessAction__NoPreprocessing(), ) expected = ai_response.AIServerResult( results=[ From 4f7d123b62155dc50abb2d088ae6b4c17afe68d0 Mon Sep 17 00:00:00 2001 From: Diretnan Domnan Date: Tue, 26 Nov 2024 22:15:42 +0100 Subject: [PATCH 15/15] Fixing python tests with previous preprocess modes on merge --- examples/python/book-search/insert_book.py | 4 +--- .../ahnlich_client_py/builders/non_blocking/ai.py | 1 - .../ahnlich_client_py/builders/non_blocking/db.py | 1 - .../tests/ai_client/test_async_client_commands.py | 4 +--- .../tests/db_client/test_async_client_commands.py | 1 - 5 files changed, 2 insertions(+), 9 deletions(-) diff --git a/examples/python/book-search/insert_book.py b/examples/python/book-search/insert_book.py index bcd80751..00c69a52 100644 --- a/examples/python/book-search/insert_book.py +++ b/examples/python/book-search/insert_book.py @@ -20,9 +20,7 @@ async def set_client(ai_client, inputs): response = await ai_client.set( store_name=ai_store_payload_with_predicates["store_name"], inputs=inputs, - preprocess_action=ai_query.PreprocessAction__RawString( - ai_query.StringAction__ErrorIfTokensExceed() - ), + preprocess_action=ai_query.PreprocessAction__NoPreprocessing(), ) print(response) diff --git a/sdk/ahnlich-client-py/ahnlich_client_py/builders/non_blocking/ai.py b/sdk/ahnlich-client-py/ahnlich_client_py/builders/non_blocking/ai.py index 44ee542c..bbe8f153 100644 --- 
a/sdk/ahnlich-client-py/ahnlich_client_py/builders/non_blocking/ai.py +++ b/sdk/ahnlich-client-py/ahnlich_client_py/builders/non_blocking/ai.py @@ -6,7 +6,6 @@ class AsyncAhnlichAIRequestBuilder(AhnlichAIRequestBuilder): - def __init__(self, tracing_id: str = None, client: BaseClient = None) -> None: self.queries: typing.List[ai_query.AIQuery] = [] self.tracing_id = tracing_id diff --git a/sdk/ahnlich-client-py/ahnlich_client_py/builders/non_blocking/db.py b/sdk/ahnlich-client-py/ahnlich_client_py/builders/non_blocking/db.py index 10442ac7..a28f9cb7 100644 --- a/sdk/ahnlich-client-py/ahnlich_client_py/builders/non_blocking/db.py +++ b/sdk/ahnlich-client-py/ahnlich_client_py/builders/non_blocking/db.py @@ -6,7 +6,6 @@ class AsyncAhnlichDBRequestBuilder(AhnlichDBRequestBuilder): - def __init__(self, tracing_id: str = None, client: BaseClient = None) -> None: self.queries: typing.List[db_query.Query] = [] self.tracing_id = tracing_id diff --git a/sdk/ahnlich-client-py/ahnlich_client_py/tests/ai_client/test_async_client_commands.py b/sdk/ahnlich-client-py/ahnlich_client_py/tests/ai_client/test_async_client_commands.py index 69e5f81a..f29fa01c 100644 --- a/sdk/ahnlich-client-py/ahnlich_client_py/tests/ai_client/test_async_client_commands.py +++ b/sdk/ahnlich-client-py/ahnlich_client_py/tests/ai_client/test_async_client_commands.py @@ -116,9 +116,7 @@ async def test_ai_client_get_pred(spin_up_ahnlich_ai): builder.set( store_name=ai_store_payload_with_predicates["store_name"], inputs=store_inputs, - preprocess_action=ai_query.PreprocessAction__RawString( - ai_query.StringAction__ErrorIfTokensExceed() - ), + preprocess_action=ai_query.PreprocessAction__NoPreprocessing(), ) expected = ai_response.AIServerResult( results=[ diff --git a/sdk/ahnlich-client-py/ahnlich_client_py/tests/db_client/test_async_client_commands.py b/sdk/ahnlich-client-py/ahnlich_client_py/tests/db_client/test_async_client_commands.py index 26b15efb..0fa601d0 100644 --- a/sdk/ahnlich-client-py/ahnlich_client_py/tests/db_client/test_async_client_commands.py +++ b/sdk/ahnlich-client-py/ahnlich_client_py/tests/db_client/test_async_client_commands.py @@ -23,7 +23,6 @@ async def test_client_sends_ping_to_db_success(module_scopped_ahnlich_db): port = module_scopped_ahnlich_db db_client = non_blocking.AhnlichDBClient(address="127.0.0.1", port=port) try: - response: db_response.ServerResult = await db_client.ping() assert len(response.results) == 1 assert response.results[0] == db_response.Result__Ok(
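
The test and demo fixes in patches 13 through 15 all perform the same migration: PreprocessAction is no
longer a wrapper around StringAction/ImageAction policies but a bare unit variant, either
NoPreprocessing or ModelPreprocessing. A minimal usage sketch of the migrated API follows. It is
illustrative only: it assumes a local ahnlich-ai instance on 127.0.0.1:1370 and the
`clients.AhnlichAIClient` import style used by the SDK tests; the store name, model choice, and
metadata key below are made up for the example.

    from ahnlich_client_py import clients
    from ahnlich_client_py.internals import ai_query

    ai_client = clients.AhnlichAIClient(address="127.0.0.1", port=1370)
    builder = ai_client.pipeline()
    # Store creation is unchanged by the preprocess-action rework.
    builder.create_store(
        store_name="demo store",
        query_model=ai_query.AIModel__AllMiniLML6V2(),
        index_model=ai_query.AIModel__AllMiniLML6V2(),
        predicates=["source"],
        store_original=True,
    )
    # Old API: PreprocessAction__RawString(StringAction__ErrorIfTokensExceed())
    # New API: a bare unit variant. ModelPreprocessing lets the proxy adapt the
    # input to the model (e.g. the tokenizer truncation path added earlier in
    # this series), while NoPreprocessing forwards the input untouched and lets
    # mismatches, such as wrong image dimensions, surface as errors.
    builder.set(
        store_name="demo store",
        inputs=[
            (
                ai_query.StoreInput__RawString("a short demo document"),
                {"source": ai_query.MetadataValue__RawString("demo")},
            )
        ],
        preprocess_action=ai_query.PreprocessAction__ModelPreprocessing(),
    )
    response = builder.exec()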