Commit

Postprocessor working on images
HAKSOAT committed Nov 25, 2024
1 parent 0a10b48 commit b405579
Showing 10 changed files with 271 additions and 134 deletions.
42 changes: 21 additions & 21 deletions ahnlich/ai/src/engine/ai/models.rs
@@ -8,7 +8,7 @@ use ahnlich_types::{
keyval::{StoreInput, StoreKey},
};
use image::{DynamicImage, GenericImageView, ImageFormat, ImageReader};
use ndarray::ArrayView;
use ndarray::{ArrayView, Ix4};
use ndarray::{Array, Ix3};
use nonzero_ext::nonzero;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
@@ -246,14 +246,15 @@ impl fmt::Display for InputAction {
#[derive(Debug)]
pub enum ModelInput {
Texts(Vec<Encoding>),
Images(Vec<ImageArray>),
Images(Array<f32, Ix4>),
}

#[derive(Debug, Clone)]
pub struct ImageArray {
array: Array<f32, Ix3>,
image: DynamicImage,
image_format: ImageFormat
image_format: ImageFormat,
onnx_transformed: bool
}

impl ImageArray {
@@ -288,13 +289,17 @@ impl ImageArray {
.map_err(|_| AIProxyError::ImageBytesDecodeError)?
.mapv(f32::from);

Ok(ImageArray { array, image, image_format: image_format.to_owned() })
Ok(ImageArray { array, image, image_format: image_format.to_owned(), onnx_transformed: false })
}

// Swapping axes from [rows, columns, channels] to [channels, rows, columns] for ONNX
pub fn onnx_transform(&mut self) {
if self.onnx_transformed {
return;
}
self.array.swap_axes(1, 2);
self.array.swap_axes(0, 1);
self.onnx_transformed = true;
}

pub fn view(&self) -> ArrayView<f32, Ix3> {
@@ -324,7 +329,7 @@ impl ImageArray {
let array = Array::from_shape_vec(shape, flattened_pixels)
.map_err(|_| AIProxyError::ImageResizeError)?
.mapv(f32::from);
Ok(ImageArray { array, image: resized_img, image_format: self.image_format })
Ok(ImageArray { array, image: resized_img, image_format: self.image_format, onnx_transformed: false })
}

pub fn crop(&self, x: u32, y: u32, width: u32, height: u32) -> Result<Self, AIProxyError> {
@@ -336,15 +341,21 @@
let array = Array::from_shape_vec(shape, flattened_pixels)
.map_err(|_| AIProxyError::ImageCropError)?
.mapv(f32::from);
Ok(ImageArray { array, image: cropped_img, image_format: self.image_format })
Ok(ImageArray { array, image: cropped_img, image_format: self.image_format, onnx_transformed: false })
}

pub fn image_dim(&self) -> (NonZeroUsize, NonZeroUsize) {
let shape = self.array.shape();
(
NonZeroUsize::new(shape[1]).expect("Array columns should be non-zero"),
NonZeroUsize::new(shape[0]).expect("Array rows should be non-zero"),
) // (width, height)
match self.onnx_transformed {
true => (
NonZeroUsize::new(shape[2]).expect("Array columns should be non-zero"),
NonZeroUsize::new(shape[1]).expect("Array channels should be non-zero"),
), // (width, channels)
false => (
NonZeroUsize::new(shape[1]).expect("Array columns should be non-zero"),
NonZeroUsize::new(shape[0]).expect("Array rows should be non-zero"),
) // (width, height)
}
}
}

@@ -367,17 +378,6 @@ impl<'de> Deserialize<'de> for ImageArray {
}
}

// impl TryFrom<StoreInput> for ModelInput {
// type Error = AIProxyError;
//
// fn try_from(value: StoreInput) -> Result<Self, Self::Error> {
// match value {
// StoreInput::RawString(s) => Ok(ModelInput::Text(s)),
// StoreInput::Image(bytes) => Ok(ModelInput::Image(ImageArray::try_new(bytes)?)),
// }
// }
// }

impl From<&ModelInput> for AIStoreInputType {
fn from(value: &ModelInput) -> AIStoreInputType {
match value {
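Aside, not part of the diff: a minimal sketch of the axis permutation that onnx_transform now performs exactly once per ImageArray, guarded by the new onnx_transformed flag. The shape below is made up for illustration; only ndarray's swap_axes is assumed.

use ndarray::Array;

fn main() {
    // A fake image in [rows, cols, channels] layout: 2 x 4 pixels, 3 channels.
    let mut hwc = Array::<f32, _>::zeros((2, 4, 3));
    assert_eq!(hwc.shape(), &[2, 4, 3]);

    // The same two swaps as ImageArray::onnx_transform:
    hwc.swap_axes(1, 2); // [rows, channels, cols]
    hwc.swap_axes(0, 1); // [channels, rows, cols]
    assert_eq!(hwc.shape(), &[3, 2, 4]);

    // After the transform, shape[0] is channels rather than rows, which is
    // why image_dim() now branches on onnx_transformed.
}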
3 changes: 2 additions & 1 deletion ahnlich/ai/src/engine/ai/providers/mod.rs
@@ -1,6 +1,7 @@
pub(crate) mod ort;
mod ort_helper;
mod processors;
pub mod processors;


use crate::cli::server::SupportedModels;
use crate::engine::ai::models::{InputAction, ModelInput};
90 changes: 54 additions & 36 deletions ahnlich/ai/src/engine/ai/providers/ort.rs
@@ -7,11 +7,11 @@ use fallible_collections::FallibleVec;
use hf_hub::{api::sync::ApiBuilder, Cache};
use itertools::Itertools;
use rayon::iter::Either;
use ort::{Session, Value};
use ort::{Session, SessionOutputs, Value};
use rayon::prelude::*;

use ahnlich_types::keyval::StoreKey;
use ndarray::{Array, Array1, Axis, Ix2, Ix3, Ix4, IxDyn, IxDynImpl};
use ndarray::{Array, Array1, ArrayView, Axis, Ix2, Ix3, Ix4, IxDyn, IxDynImpl};
use std::convert::TryFrom;
use std::default::Default;
use std::fmt;
@@ -22,7 +22,7 @@ use crate::engine::ai::providers::processors::preprocessor::{ImagePreprocessorFi
use crate::engine::ai::providers::ort_helper::normalize;
use ndarray::s;
use tokenizers::Tokenizer;
use crate::engine::ai::providers::processors::postprocessor::{ORTPostprocessor, ORTTextPostprocessor};
use crate::engine::ai::providers::processors::postprocessor::{ORTImagePostprocessor, ORTPostprocessor, ORTTextPostprocessor};

#[derive(Default)]
pub struct ORTProvider {
@@ -215,44 +215,46 @@ impl ORTProvider {
}
}

pub fn batch_inference_image(&self, inputs: Vec<ImageArray>) -> Result<Vec<StoreKey>, AIProxyError> {
pub fn postprocess_image_inference(&self, embeddings: SessionOutputs) -> Result<Array<f32, Ix2>, AIProxyError> {
match &self.postprocessor {
Some(ORTPostprocessor::Image(postprocessor)) => {
let output_data = postprocessor.process(embeddings)
.map_err(
|e| AIProxyError::ModelProviderPostprocessingError(
format!("Postprocessing failed for {:?} with error: {}",
self.supported_models.unwrap().to_string(), e)
))?;
Ok(output_data)
}
_ => Err(AIProxyError::ModelPostprocessingError {
model_name: self.supported_models.unwrap().to_string(),
message: "Postprocessor not initialized".to_string(),
})
}
}

pub fn batch_inference_image(&self, inputs: Array<f32, Ix4>) -> Result<Array<f32, Ix2>, AIProxyError> {
let model = match &self.model {
Some(ORTModel::Image(model)) => model,
_ => return Err(AIProxyError::AIModelNotSupported { model_name: self.supported_models.unwrap().to_string() }),
};
let pixel_values_array = self.preprocess_images(inputs)?;
match &model.session {
Some(session) => {
let session_inputs = ort::inputs![
model.input_params.first().expect("Hardcoded in parameters")
.as_str() => pixel_values_array.view(),
.as_str() => inputs.view(),
].map_err(|e| AIProxyError::ModelProviderPreprocessingError(e.to_string()))?;

let outputs = session.run(session_inputs)
.map_err(|e| AIProxyError::ModelProviderRunInferenceError(e.to_string()))?;
let last_hidden_state_key = match outputs.len() {
1 => outputs.keys().next().unwrap(),
_ => model.output_param.as_str(),
};

let output_data = outputs[last_hidden_state_key]
.try_extract_tensor::<f32>()
.map_err(|e| AIProxyError::ModelProviderPostprocessingError(e.to_string()))?;
let store_keys = output_data
.axis_iter(Axis(0))
.into_par_iter()
.map(|row| {
let embeddings = normalize(row.as_slice().unwrap());
StoreKey(<Array1<f32>>::from(embeddings))
})
.collect();
Ok(store_keys)
let embeddings = self.postprocess_image_inference(outputs)?;
Ok(embeddings)
}
None => Err(AIProxyError::AIModelNotInitialized)
}
}

pub fn batch_inference_text(&self, encodings: Vec<Encoding>) -> Result<Vec<StoreKey>, AIProxyError> {
pub fn batch_inference_text(&self, encodings: Vec<Encoding>) -> Result<Array<f32, Ix2>, AIProxyError> {
let model = match &self.model {
Some(ORTModel::Text(model)) => model,
_ => return Err(AIProxyError::AIModelNotSupported { model_name: self.supported_models.unwrap().to_string() }),
@@ -338,13 +340,7 @@ impl ORTProvider {
let session_output = session_output
.to_owned();
let embeddings = self.postprocess_text_embeddings(session_output, attention_mask_array)?;
println!("Embeddings: {:?}", embeddings);
let store_keys = embeddings
.axis_iter(Axis(0))
.into_par_iter()
.map(|embedding| StoreKey(<Array1<f32>>::from(embedding.to_owned())))
.collect();
Ok(store_keys)
Ok(embeddings.to_owned())
}
None => Err(AIProxyError::AIModelNotInitialized),
}
@@ -407,8 +403,9 @@ impl ProviderTrait for ORTProvider {
}));
let mut preprocessor = ORTImagePreprocessor::default();
preprocessor.load(model_repo, preprocessor_files)?;
self.preprocessor = Some(ORTPreprocessor::Image(preprocessor)
);
self.preprocessor = Some(ORTPreprocessor::Image(preprocessor));
let postprocessor = ORTImagePostprocessor::load(supported_model)?;
self.postprocessor = Some(ORTPostprocessor::Image(postprocessor));
},
ORTModel::Text(ORTTextModel {
weights_file,
@@ -487,14 +484,35 @@
) -> Result<Vec<StoreKey>, AIProxyError> {

match input {
ModelInput::Images(images) => self.batch_inference_image(images),
ModelInput::Images(images) => {
let mut store_keys: Vec<StoreKey> = FallibleVec::try_with_capacity(images.len())?;

for batch_image in images.axis_chunks_iter(Axis(0), 16).into_iter() {
let embeddings = self.batch_inference_image(batch_image.to_owned())?;
let new_store_keys: Vec<StoreKey> = embeddings
.axis_iter(Axis(0))
.into_par_iter()
.map(|embedding| StoreKey(<Array1<f32>>::from(embedding.to_owned()))
)
.collect();
store_keys.extend(new_store_keys);
}
Ok(store_keys)
},
ModelInput::Texts(encodings) => {
let mut store_keys: Vec<_> = FallibleVec::try_with_capacity(
let mut store_keys: Vec<StoreKey> = FallibleVec::try_with_capacity(
encodings.len()
)?;

for batch_encoding in encodings.into_iter().chunks(16).into_iter() {
store_keys.extend(self.batch_inference_text(batch_encoding.collect())?);
let embeddings = self.batch_inference_text(batch_encoding.collect())?;
let new_store_keys: Vec<StoreKey> = embeddings
.axis_iter(Axis(0))
.into_par_iter()
.map(|embedding| StoreKey(<Array1<f32>>::from(embedding.to_owned()))
)
.collect();
store_keys.extend(new_store_keys);
}
Ok(store_keys)
},
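Aside, illustrative only, with simplified stand-ins for the real ahnlich types: the batching pattern predict now uses for images: split the stacked [N, C, H, W] tensor into chunks of 16 along Axis(0), run inference per chunk, and turn each row of the resulting [batch, dim] embeddings into a store key.

use ndarray::{Array, Array1, Array2, Array4, Axis};

// Stand-in for ahnlich_types::keyval::StoreKey.
struct StoreKey(Array1<f32>);

// Stand-in for ORTProvider::batch_inference_image: mean-pools each image
// into a fake 4-dimensional embedding instead of running an ONNX session.
fn fake_batch_inference(batch: Array4<f32>) -> Array2<f32> {
    let n = batch.shape()[0];
    Array::from_shape_fn((n, 4), |(i, j)| {
        batch.index_axis(Axis(0), i).mean().unwrap() + j as f32
    })
}

fn main() {
    // 40 already-preprocessed images stacked into one [N, C, H, W] tensor.
    let images = Array4::<f32>::zeros((40, 3, 224, 224));
    let mut store_keys: Vec<StoreKey> = Vec::with_capacity(images.shape()[0]);

    for batch in images.axis_chunks_iter(Axis(0), 16) {
        let embeddings = fake_batch_inference(batch.to_owned());
        store_keys.extend(
            embeddings
                .axis_iter(Axis(0))
                .map(|row| StoreKey(row.to_owned())),
        );
    }
    assert_eq!(store_keys.len(), 40);
}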
@@ -1,4 +1,6 @@
use image::image_dimensions;
use ndarray::{ArrayView, Ix3};
use std::sync::Mutex;
use rayon::iter::{IntoParallelRefMutIterator, ParallelIterator};
use crate::engine::ai::providers::processors::{Preprocessor, PreprocessorData};
use crate::error::AIProxyError;
@@ -9,16 +11,22 @@ impl Preprocessor for ImageArrayToNdArray {
fn process(&self, data: PreprocessorData) -> Result<PreprocessorData, AIProxyError> {
match data {
PreprocessorData::ImageArray(mut arrays) => {
let array_views: Vec<ArrayView<f32, Ix3>> = arrays
.par_iter_mut()
.map(|image_arr| {
image_arr.onnx_transform();
image_arr.view()
})
.collect();
let mut array_shapes = Mutex::new(vec![]);
let mut array_views = Mutex::new(vec![]);
arrays.par_iter_mut().for_each(|image_arr| {
image_arr.onnx_transform();
array_shapes.lock().unwrap().push(image_arr.image_dim());
array_views.lock().unwrap().push(image_arr.view());
});

let array_shapes = array_shapes.into_inner().unwrap();
let array_views = array_views.into_inner().unwrap();

let pixel_values_array = ndarray::stack(ndarray::Axis(0), &array_views)
.map_err(|e| AIProxyError::EmbeddingShapeError(e.to_string()))?;
.map_err(|e| AIProxyError::ImageArrayToNdArrayError {
message: format!("Images must have same dimensions, instead found: {:?}. \
NB: Dimensions listed are not in same order as images provided.", array_shapes),
})?;
Ok(PreprocessorData::NdArray3C(pixel_values_array))
}
_ => Err(AIProxyError::ImageArrayToNdArrayError {
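Aside, a standalone sketch with made-up shapes: why the rewritten error collects every image_dim() before stacking. ndarray::stack along Axis(0) only succeeds when all the [channels, rows, cols] views agree, so a single mismatched image fails the whole batch.

use ndarray::{stack, Array3, Axis};

fn main() {
    let a = Array3::<f32>::zeros((3, 224, 224));
    let b = Array3::<f32>::zeros((3, 224, 224));
    // Matching dimensions stack into a [2, 3, 224, 224] batch.
    assert!(stack(Axis(0), &[a.view(), b.view()]).is_ok());

    let c = Array3::<f32>::zeros((3, 256, 256));
    // One mismatched image makes the whole stack fail, which is the case
    // the new ImageArrayToNdArrayError message reports.
    assert!(stack(Axis(0), &[a.view(), c.view()]).is_err());
}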
16 changes: 15 additions & 1 deletion ahnlich/ai/src/engine/ai/providers/processors/mod.rs
@@ -1,6 +1,7 @@
use crate::engine::ai::models::ImageArray;
use crate::error::AIProxyError;
use ndarray::{Array, Ix2, Ix3, Ix4};
use ort::SessionOutputs;
use tokenizers::Encoding;

pub mod normalize;
@@ -12,6 +13,7 @@ pub mod preprocessor;
pub mod tokenize;
pub mod postprocessor;
pub mod pooling;
mod onnx_output_transform;

pub const CONV_NEXT_FEATURE_EXTRACTOR_CENTER_CROP_THRESHOLD: u32 = 384;

@@ -30,7 +32,19 @@ pub enum PreprocessorData {
EncodedText(Vec<Encoding>),
}

pub enum PostprocessorData {
impl PreprocessorData {
pub fn into_ndarray3c(self) -> Result<Array<f32, Ix4>, AIProxyError> {
match self {
PreprocessorData::NdArray3C(array) => Ok(array),
_ => Err(AIProxyError::ModelProviderPreprocessingError(
"`into_ndarray3c` only works for PreprocessorData::NdArray3C".to_string()
)),
}
}
}

pub enum PostprocessorData<'r, 's> {
OnnxOutput(SessionOutputs<'r, 's>),
NdArray2(Array<f32, Ix2>),
NdArray3(Array<f32, Ix3>)
}
@@ -0,0 +1,56 @@
use ndarray::{Ix2, Ix3};
use crate::engine::ai::providers::processors::{Postprocessor, PostprocessorData};
use crate::error::AIProxyError;


pub struct OnnxOutputTransform {
output_key: String
}

impl OnnxOutputTransform {
pub fn new(output_key: String) -> Self {
Self { output_key }
}
}

impl Postprocessor for OnnxOutputTransform {
fn process(&self, data: PostprocessorData) -> Result<PostprocessorData, AIProxyError> {
match data {
PostprocessorData::OnnxOutput(onnx_output) => {
let output = onnx_output.get(self.output_key.as_str())
.ok_or_else(|| AIProxyError::OnnxOutputTransformError {
message: format!("Output key '{}' not found in the OnnxOutput.", self.output_key),
})?;
let output = output.try_extract_tensor::<f32>().map_err(
|_| AIProxyError::OnnxOutputTransformError {
message: "Failed to extract tensor from OnnxOutput.".to_string(),
}
)?;
match output.ndim() {
2 => {
let output = output.into_dimensionality::<Ix2>().map_err(
|_| AIProxyError::OnnxOutputTransformError {
message: "Failed to convert Dyn tensor to 2D array.".to_string(),
}
)?;
Ok(PostprocessorData::NdArray2(output.to_owned()))
},
3 => {
let output = output.into_dimensionality::<Ix3>().map_err(
|_| AIProxyError::OnnxOutputTransformError {
message: "Failed to convert Dyn tensor to 3D array.".to_string(),
}
)?;
Ok(PostprocessorData::NdArray3(output.to_owned()))
},
_ => Err(AIProxyError::OnnxOutputTransformError {
message: "Only 2D and 3D tensors are supported.".to_string(),
}),
}
}
_ => Err(AIProxyError::OnnxOutputTransformError {
message: "Only OnnxOutput is supported".to_string(),
}),
}
}
}
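Aside, synthetic arrays only, no onnxruntime session involved: the IxDyn-to-fixed-rank conversion OnnxOutputTransform relies on. into_dimensionality succeeds only when the runtime rank matches the requested one, which is what the match on output.ndim() encodes.

use ndarray::{Array, Ix2, Ix3, IxDyn};

fn main() {
    // A dynamically-shaped 2D tensor, like a pooled ONNX output.
    let two_d = Array::<f32, _>::zeros(IxDyn(&[4, 8]));
    assert!(two_d.clone().into_dimensionality::<Ix2>().is_ok());
    assert!(two_d.into_dimensionality::<Ix3>().is_err());

    // A 3D tensor (e.g. token-level hidden states) converts to Ix3 instead.
    let three_d = Array::<f32, _>::zeros(IxDyn(&[4, 8, 16]));
    assert!(three_d.into_dimensionality::<Ix3>().is_ok());
}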