Skip to content

Commit

Permalink
Preprocessors and Postprocessors now work
Browse files Browse the repository at this point in the history
  • Loading branch information
HAKSOAT committed Nov 26, 2024
1 parent b405579 commit 92cddd2
Show file tree
Hide file tree
Showing 20 changed files with 776 additions and 750 deletions.
49 changes: 33 additions & 16 deletions ahnlich/ai/src/engine/ai/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,20 @@ use crate::engine::ai::providers::ProviderTrait;
use crate::error::AIProxyError;
use ahnlich_types::{
ai::{AIModel, AIStoreInputType},
keyval::{StoreInput, StoreKey},
keyval::StoreKey,
};
use image::{DynamicImage, GenericImageView, ImageFormat, ImageReader};
use ndarray::{ArrayView, Ix4};
use ndarray::{Array, Ix3};
use ndarray::{ArrayView, Ix4};
use nonzero_ext::nonzero;
use serde::de::Error as DeError;
use serde::ser::Error as SerError;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::fmt;
use std::io::Cursor;
use std::num::NonZeroUsize;
use std::path::Path;
use strum::Display;
use serde::ser::Error as SerError;
use serde::de::Error as DeError;
use tokenizers::Encoding;

#[derive(Display)]
Expand Down Expand Up @@ -254,7 +254,7 @@ pub struct ImageArray {
array: Array<f32, Ix3>,
image: DynamicImage,
image_format: ImageFormat,
onnx_transformed: bool
onnx_transformed: bool,
}

impl ImageArray {
Expand Down Expand Up @@ -289,7 +289,12 @@ impl ImageArray {
.map_err(|_| AIProxyError::ImageBytesDecodeError)?
.mapv(f32::from);

Ok(ImageArray { array, image, image_format: image_format.to_owned(), onnx_transformed: false })
Ok(ImageArray {
array,
image,
image_format: image_format.to_owned(),
onnx_transformed: false,
})
}

// Swapping axes from [rows, columns, channels] to [channels, rows, columns] for ONNX
Expand All @@ -308,28 +313,35 @@ impl ImageArray {

pub fn get_bytes(&self) -> Result<Vec<u8>, AIProxyError> {
let mut buffer = Cursor::new(Vec::new());
let _ = &self.image
let _ = &self
.image
.write_to(&mut buffer, self.image_format)
.map_err(|_| AIProxyError::ImageBytesEncodeError)?;
let bytes = buffer.into_inner();
Ok(bytes)
}

pub fn resize(&self, width: u32, height: u32, filter: Option<image::imageops::FilterType>) -> Result<Self, AIProxyError> {
pub fn resize(
&self,
width: u32,
height: u32,
filter: Option<image::imageops::FilterType>,
) -> Result<Self, AIProxyError> {
let filter_type = filter.unwrap_or(image::imageops::FilterType::CatmullRom);
let resized_img = self.image.resize_exact(
width,
height,
filter_type,
);
let resized_img = self.image.resize_exact(width, height, filter_type);
let channels = resized_img.color().channel_count();
let shape = (height as usize, width as usize, channels as usize);

let flattened_pixels = resized_img.clone().into_bytes();
let array = Array::from_shape_vec(shape, flattened_pixels)
.map_err(|_| AIProxyError::ImageResizeError)?
.mapv(f32::from);
Ok(ImageArray { array, image: resized_img, image_format: self.image_format, onnx_transformed: false })
Ok(ImageArray {
array,
image: resized_img,
image_format: self.image_format,
onnx_transformed: false,
})
}

pub fn crop(&self, x: u32, y: u32, width: u32, height: u32) -> Result<Self, AIProxyError> {
Expand All @@ -341,7 +353,12 @@ impl ImageArray {
let array = Array::from_shape_vec(shape, flattened_pixels)
.map_err(|_| AIProxyError::ImageCropError)?
.mapv(f32::from);
Ok(ImageArray { array, image: cropped_img, image_format: self.image_format, onnx_transformed: false })
Ok(ImageArray {
array,
image: cropped_img,
image_format: self.image_format,
onnx_transformed: false,
})
}

pub fn image_dim(&self) -> (NonZeroUsize, NonZeroUsize) {
Expand All @@ -354,7 +371,7 @@ impl ImageArray {
false => (
NonZeroUsize::new(shape[1]).expect("Array columns should be non-zero"),
NonZeroUsize::new(shape[0]).expect("Array rows should be non-zero"),
) // (width, height)
), // (width, height)
}
}
}
Expand Down
2 changes: 0 additions & 2 deletions ahnlich/ai/src/engine/ai/providers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ pub(crate) mod ort;
mod ort_helper;
pub mod processors;


use crate::cli::server::SupportedModels;
use crate::engine::ai::models::{InputAction, ModelInput};
use crate::engine::ai::providers::ort::ORTProvider;
Expand All @@ -11,7 +10,6 @@ use ahnlich_types::keyval::StoreKey;
use std::path::Path;
use strum::EnumIter;


#[derive(Debug, EnumIter)]
pub enum ModelProviders {
ORT(ORTProvider),
Expand Down
Loading

0 comments on commit 92cddd2

Please sign in to comment.