Set up ORT for Text, ORT for Image yet to run
HAKSOAT committed Nov 24, 2024
1 parent 40cb9ce commit 0a10b48
Showing 28 changed files with 855 additions and 748 deletions.
1 change: 1 addition & 0 deletions ahnlich/ai/src/cli/server.rs
@@ -135,6 +135,7 @@ impl Default for AIProxyConfig {
             SupportedModels::AllMiniLML6V2,
             SupportedModels::AllMiniLML12V2,
             SupportedModels::BGEBaseEnV15,
+            SupportedModels::BGELargeEnV15,
             SupportedModels::ClipVitB32Text,
             SupportedModels::Resnet50,
             SupportedModels::ClipVitB32Image,
55 changes: 21 additions & 34 deletions ahnlich/ai/src/engine/ai/models.rs
@@ -1,5 +1,4 @@
 use crate::cli::server::SupportedModels;
-use crate::engine::ai::providers::fastembed::FastEmbedProvider;
 use crate::engine::ai::providers::ort::ORTProvider;
 use crate::engine::ai::providers::ModelProviders;
 use crate::engine::ai::providers::ProviderTrait;
@@ -20,6 +19,7 @@ use std::path::Path;
 use strum::Display;
 use serde::ser::Error as SerError;
 use serde::de::Error as DeError;
+use tokenizers::Encoding;

 #[derive(Display)]
 pub enum ModelType {
@@ -47,7 +47,7 @@ impl From<&AIModel> for Model {
                 model_type: ModelType::Text {
                     max_input_tokens: nonzero!(256usize),
                 },
-                provider: ModelProviders::FastEmbed(FastEmbedProvider::new()),
+                provider: ModelProviders::ORT(ORTProvider::new()),
                 supported_model: SupportedModels::AllMiniLML6V2,
                 description: String::from("Sentence Transformer model, with 6 layers, version 2"),
                 embedding_size: nonzero!(384usize),
@@ -57,7 +57,7 @@ impl From<&AIModel> for Model {
                     // Token size source: https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2#intended-uses
                     max_input_tokens: nonzero!(256usize),
                 },
-                provider: ModelProviders::FastEmbed(FastEmbedProvider::new()),
+                provider: ModelProviders::ORT(ORTProvider::new()),
                 supported_model: SupportedModels::AllMiniLML12V2,
                 description: String::from("Sentence Transformer model, with 12 layers, version 2."),
                 embedding_size: nonzero!(384usize),
@@ -67,7 +67,7 @@ impl From<&AIModel> for Model {
                     // Token size source: https://huggingface.co/BAAI/bge-large-en/discussions/11#64e44de1623074ac850aa1ae
                     max_input_tokens: nonzero!(512usize),
                 },
-                provider: ModelProviders::FastEmbed(FastEmbedProvider::new()),
+                provider: ModelProviders::ORT(ORTProvider::new()),
                 supported_model: SupportedModels::BGEBaseEnV15,
                 description: String::from(
                     "BAAI General Embedding model with English support, base scale, version 1.5.",
@@ -78,7 +78,7 @@ impl From<&AIModel> for Model {
                 model_type: ModelType::Text {
                     max_input_tokens: nonzero!(512usize),
                 },
-                provider: ModelProviders::FastEmbed(FastEmbedProvider::new()),
+                provider: ModelProviders::ORT(ORTProvider::new()),
                 supported_model: SupportedModels::BGELargeEnV15,
                 description: String::from(
                     "BAAI General Embedding model with English support, large scale, version 1.5.",
@@ -134,14 +134,11 @@ impl Model {
     #[tracing::instrument(skip(self))]
     pub fn model_ndarray(
         &self,
-        storeinput: Vec<ModelInput>,
+        modelinput: ModelInput,
         action_type: &InputAction,
     ) -> Result<Vec<StoreKey>, AIProxyError> {
         let store_keys = match &self.provider {
-            ModelProviders::FastEmbed(provider) => {
-                provider.run_inference(storeinput, action_type)?
-            }
-            ModelProviders::ORT(provider) => provider.run_inference(storeinput, action_type)?,
+            ModelProviders::ORT(provider) => provider.run_inference(modelinput, action_type)?,
         };
         Ok(store_keys)
     }
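The model_ndarray change narrows the API from a Vec&lt;ModelInput&gt; to a single ModelInput, since each variant now carries a whole batch itself. A minimal caller-side sketch under that reading; the `model`, `encodings`, and `action` values are assumed to exist at the call site:

```rust
// Hypothetical call site: `model` is a loaded Model, `encodings` is a
// Vec<Encoding> produced by a tokenizer, `action` is the caller's InputAction.
let input = ModelInput::Texts(encodings); // the whole batch travels as one value
let store_keys = model.model_ndarray(input, &action)?;
```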
@@ -173,10 +170,6 @@ impl Model {
     pub fn setup_provider(&mut self, cache_location: &Path) {
         let supported_model = self.supported_model;
         match &mut self.provider {
-            ModelProviders::FastEmbed(provider) => {
-                provider.set_model(&supported_model);
-                provider.set_cache_location(cache_location);
-            }
             ModelProviders::ORT(provider) => {
                 provider.set_model(&supported_model);
                 provider.set_cache_location(cache_location);
@@ -186,9 +179,6 @@

     pub fn load(&mut self) -> Result<(), AIProxyError> {
         match &mut self.provider {
-            ModelProviders::FastEmbed(provider) => {
-                provider.load_model()?;
-            }
             ModelProviders::ORT(provider) => {
                 provider.load_model()?;
             }
@@ -198,9 +188,6 @@

     pub fn get(&self) -> Result<(), AIProxyError> {
         match &self.provider {
-            ModelProviders::FastEmbed(provider) => {
-                provider.get_model()?;
-            }
             ModelProviders::ORT(provider) => {
                 provider.get_model()?;
             }
@@ -258,8 +245,8 @@ impl fmt::Display for InputAction {

 #[derive(Debug)]
 pub enum ModelInput {
-    Text(String),
-    Image(ImageArray),
+    Texts(Vec<Encoding>),
+    Images(Vec<ImageArray>),
 }

 #[derive(Debug, Clone)]
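ModelInput::Texts now carries pre-tokenized Encoding values from the tokenizers crate instead of raw strings, so tokenization moves to whoever builds the input. A minimal sketch of producing that shape with the tokenizers API; the tokenizer.json path is a placeholder:

```rust
use tokenizers::{Encoding, Tokenizer};

fn encode_texts(texts: Vec<String>) -> tokenizers::Result<Vec<Encoding>> {
    // Placeholder path; real code would load the tokenizer bundled with the model.
    let tokenizer = Tokenizer::from_file("path/to/tokenizer.json")?;
    // `true` adds special tokens ([CLS]/[SEP] for BERT-style models).
    tokenizer.encode_batch(texts, true)
}
```

A caller could then build `ModelInput::Texts(encode_texts(texts)?)` and hand it to model_ndarray as shown earlier.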
@@ -380,22 +367,22 @@ impl<'de> Deserialize<'de> for ImageArray {
     }
 }

-impl TryFrom<StoreInput> for ModelInput {
-    type Error = AIProxyError;
-
-    fn try_from(value: StoreInput) -> Result<Self, Self::Error> {
-        match value {
-            StoreInput::RawString(s) => Ok(ModelInput::Text(s)),
-            StoreInput::Image(bytes) => Ok(ModelInput::Image(ImageArray::try_new(bytes)?)),
-        }
-    }
-}
+// impl TryFrom<StoreInput> for ModelInput {
+//     type Error = AIProxyError;
+//
+//     fn try_from(value: StoreInput) -> Result<Self, Self::Error> {
+//         match value {
+//             StoreInput::RawString(s) => Ok(ModelInput::Text(s)),
+//             StoreInput::Image(bytes) => Ok(ModelInput::Image(ImageArray::try_new(bytes)?)),
+//         }
+//     }
+// }

 impl From<&ModelInput> for AIStoreInputType {
     fn from(value: &ModelInput) -> AIStoreInputType {
         match value {
-            ModelInput::Text(_) => AIStoreInputType::RawString,
-            ModelInput::Image(_) => AIStoreInputType::Image,
+            ModelInput::Texts(_) => AIStoreInputType::RawString,
+            ModelInput::Images(_) => AIStoreInputType::Image,
         }
     }
 }
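The TryFrom impl is commented out rather than updated, presumably because the new ModelInput::Texts(Vec&lt;Encoding&gt;) cannot be built from a StoreInput alone: tokenization needs a tokenizer, and TryFrom has no way to receive one. A hypothetical replacement would be an explicit, tokenizer-aware constructor along these lines; the function name and both error variants are illustrative, not from the codebase:

```rust
use tokenizers::Tokenizer;

// Hypothetical helper, not in this commit: folds a uniform batch of
// StoreInputs into one batched ModelInput; mixed batches are rejected.
fn to_model_input(
    inputs: Vec<StoreInput>,
    tokenizer: &Tokenizer,
) -> Result<ModelInput, AIProxyError> {
    let mut texts = Vec::new();
    let mut images = Vec::new();
    for input in inputs {
        match input {
            StoreInput::RawString(s) => texts.push(s),
            StoreInput::Image(bytes) => images.push(ImageArray::try_new(bytes)?),
        }
    }
    match (texts.is_empty(), images.is_empty()) {
        (false, true) => {
            let encodings = tokenizer
                .encode_batch(texts, true)
                .map_err(|_| AIProxyError::TokenizationFailed)?; // illustrative variant
            Ok(ModelInput::Texts(encodings))
        }
        (true, false) => Ok(ModelInput::Images(images)),
        _ => Err(AIProxyError::MixedInputBatch), // illustrative variant
    }
}
```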
219 changes: 0 additions & 219 deletions ahnlich/ai/src/engine/ai/providers/fastembed.rs

This file was deleted.
