diff --git a/ahnlich/Cargo.lock b/ahnlich/Cargo.lock index 5061ab30..0edf37d3 100644 --- a/ahnlich/Cargo.lock +++ b/ahnlich/Cargo.lock @@ -46,6 +46,7 @@ dependencies = [ "pretty_assertions", "thiserror", "tokio", + "typed-builder", "utils", ] @@ -3909,6 +3910,26 @@ version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" +[[package]] +name = "typed-builder" +version = "0.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7e14ed59dc8b7b26cacb2a92bad2e8b1f098806063898ab42a3bd121d7d45e75" +dependencies = [ + "typed-builder-macro", +] + +[[package]] +name = "typed-builder-macro" +version = "0.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "560b82d656506509d43abe30e0ba64c56b1953ab3d4fe7ba5902747a7a3cedd5" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.87", +] + [[package]] name = "typegen" version = "0.1.0" diff --git a/ahnlich/client/Cargo.toml b/ahnlich/client/Cargo.toml index c2ea27a6..82b68031 100644 --- a/ahnlich/client/Cargo.toml +++ b/ahnlich/client/Cargo.toml @@ -29,3 +29,4 @@ ai = { path = "../ai", version = "*" } pretty_assertions.workspace = true ndarray.workspace = true utils = { path = "../utils", version = "*" } +typed-builder = "0.20.0" diff --git a/ahnlich/client/src/ai.rs b/ahnlich/client/src/ai.rs index b4b68f4f..e208d0e7 100644 --- a/ahnlich/client/src/ai.rs +++ b/ahnlich/client/src/ai.rs @@ -1,3 +1,4 @@ +use crate::builders::ai as ai_params; use crate::conn::{AIConn, Connection}; use crate::error::AhnlichError; use crate::prelude::*; @@ -211,148 +212,138 @@ impl AIClient { )) } - #[allow(clippy::too_many_arguments)] pub async fn create_store( &self, - store: StoreName, - query_model: AIModel, - index_model: AIModel, - predicates: HashSet, - non_linear_indices: HashSet, - error_if_exists: bool, - store_original: bool, - tracing_id: Option, + store_params: ai_params::CreateStoreParams, ) -> Result { self.exec( AIQuery::CreateStore { - store, - query_model, - index_model, - predicates, - non_linear_indices, - error_if_exists, - store_original, + store: store_params.store, + query_model: store_params.query_model, + index_model: store_params.index_model, + predicates: store_params.predicates, + non_linear_indices: store_params.non_linear_indices, + error_if_exists: store_params.error_if_exists, + store_original: store_params.store_original, }, - tracing_id, + store_params.tracing_id, ) .await } pub async fn get_pred( &self, - store: StoreName, - condition: PredicateCondition, - tracing_id: Option, + params: ai_params::GetPredParams, ) -> Result { - self.exec(AIQuery::GetPred { store, condition }, tracing_id) - .await + self.exec( + AIQuery::GetPred { + store: params.store, + condition: params.condition, + }, + params.tracing_id, + ) + .await } pub async fn get_sim_n( &self, - store: StoreName, - search_input: StoreInput, - condition: Option, - closest_n: NonZeroUsize, - algorithm: Algorithm, - tracing_id: Option, + params: ai_params::GetSimNParams, ) -> Result { self.exec( AIQuery::GetSimN { - store, - search_input, - condition, - closest_n, - algorithm, + store: params.store, + search_input: params.search_input, + condition: params.condition, + closest_n: params.closest_n, + algorithm: params.algorithm, }, - tracing_id, + params.tracing_id, ) .await } pub async fn create_pred_index( &self, - store: StoreName, - predicates: HashSet, - tracing_id: Option, + params: ai_params::CreatePredIndexParams, ) -> Result { - self.exec(AIQuery::CreatePredIndex { store, predicates }, tracing_id) - .await + self.exec( + AIQuery::CreatePredIndex { + store: params.store, + predicates: params.predicates, + }, + params.tracing_id, + ) + .await } pub async fn create_non_linear_algorithm_index( &self, - store: StoreName, - non_linear_indices: HashSet, - tracing_id: Option, + params: ai_params::CreateNonLinearAlgorithmIndexParams, ) -> Result { self.exec( AIQuery::CreateNonLinearAlgorithmIndex { - store, - non_linear_indices, + store: params.store, + non_linear_indices: params.non_linear_indices, }, - tracing_id, + params.tracing_id, ) .await } pub async fn drop_pred_index( &self, - store: StoreName, - predicates: HashSet, - error_if_not_exists: bool, - tracing_id: Option, + params: ai_params::DropPredIndexParams, ) -> Result { self.exec( AIQuery::DropPredIndex { - store, - predicates, - error_if_not_exists, + store: params.store, + predicates: params.predicates, + error_if_not_exists: params.error_if_not_exists, }, - tracing_id, + params.tracing_id, ) .await } pub async fn set( &self, - store: StoreName, - inputs: Vec<(StoreInput, StoreValue)>, - preprocess_action: PreprocessAction, - tracing_id: Option, + params: ai_params::SetParams, ) -> Result { self.exec( AIQuery::Set { - store, - inputs, - preprocess_action, + store: params.store, + inputs: params.inputs, + preprocess_action: params.preprocess_action, }, - tracing_id, + params.tracing_id, ) .await } pub async fn del_key( &self, - store: StoreName, - key: StoreInput, - tracing_id: Option, + params: ai_params::DelKeyParams, ) -> Result { - self.exec(AIQuery::DelKey { store, key }, tracing_id).await + self.exec( + AIQuery::DelKey { + store: params.store, + key: params.key, + }, + params.tracing_id, + ) + .await } pub async fn drop_store( &self, - store: StoreName, - error_if_not_exists: bool, - tracing_id: Option, + params: ai_params::DropStoreParams, ) -> Result { self.exec( AIQuery::DropStore { - store, - error_if_not_exists, + store: params.store, + error_if_not_exists: params.error_if_not_exists, }, - tracing_id, + params.tracing_id, ) .await } diff --git a/ahnlich/client/src/builders/ai.rs b/ahnlich/client/src/builders/ai.rs new file mode 100644 index 00000000..b481f015 --- /dev/null +++ b/ahnlich/client/src/builders/ai.rs @@ -0,0 +1,129 @@ +use std::{collections::HashSet, num::NonZeroUsize}; +use typed_builder::TypedBuilder; + +use ahnlich_types::{ + ai::{AIModel, PreprocessAction}, + keyval::{StoreInput, StoreName, StoreValue}, + metadata::MetadataKey, + predicate::PredicateCondition, + similarity::{Algorithm, NonLinearAlgorithm}, +}; + +#[derive(TypedBuilder)] +pub struct CreateStoreParams { + #[builder(setter(into, transform = |s: String| StoreName(s)))] + pub store: StoreName, + + #[builder(default = AIModel::AllMiniLML6V2)] + pub query_model: AIModel, + + #[builder(default = AIModel::AllMiniLML6V2)] + pub index_model: AIModel, + + #[builder(default = HashSet::new())] + pub predicates: HashSet, + + #[builder(default = HashSet::new())] + pub non_linear_indices: HashSet, + + #[builder(default = true)] + pub error_if_exists: bool, + + #[builder(default = true)] + pub store_original: bool, + + pub tracing_id: Option, +} + +#[derive(TypedBuilder)] +pub struct GetPredParams { + #[builder(setter(into, transform = |s: String| StoreName(s)))] + pub store: StoreName, + pub condition: PredicateCondition, + pub tracing_id: Option, +} + +#[derive(TypedBuilder)] +pub struct GetSimNParams { + #[builder(setter(into, transform = |s: String| StoreName(s)))] + pub store: StoreName, + pub search_input: StoreInput, + pub condition: Option, + + #[builder(default=NonZeroUsize::new(1).unwrap())] + pub closest_n: NonZeroUsize, + + #[builder(default=Algorithm::CosineSimilarity)] + pub algorithm: Algorithm, + + pub tracing_id: Option, +} + +#[derive(TypedBuilder)] +pub struct CreatePredIndexParams { + #[builder(setter(into, transform = |s: String| StoreName(s)))] + pub store: StoreName, + + //#[builder(default = HashSet::new())] + pub predicates: HashSet, + + pub tracing_id: Option, +} + +#[derive(TypedBuilder)] +pub struct CreateNonLinearAlgorithmIndexParams { + #[builder(setter(into, transform = |s: String| StoreName(s)))] + pub store: StoreName, + + #[builder(default = HashSet::from_iter(&[NonLinearAlgorithm::KDTree]))] + pub non_linear_indices: HashSet, + + pub tracing_id: Option, +} + +#[derive(TypedBuilder)] +pub struct DropPredIndexParams { + #[builder(setter(into, transform = |s: String| StoreName(s)))] + pub store: StoreName, + + //#[builder(default = HashSet::new())] + pub predicates: HashSet, + + #[builder(default = true)] + pub error_if_not_exists: bool, + + pub tracing_id: Option, +} + +#[derive(TypedBuilder)] +pub struct SetParams { + #[builder(setter(into, transform = |s: String| StoreName(s)))] + pub store: StoreName, + + pub inputs: Vec<(StoreInput, StoreValue)>, + + #[builder(default = PreprocessAction::NoPreprocessing)] + pub preprocess_action: PreprocessAction, + + pub tracing_id: Option, +} + +#[derive(TypedBuilder)] +pub struct DelKeyParams { + #[builder(setter(into, transform = |s: String| StoreName(s)))] + pub store: StoreName, + + pub key: StoreInput, + pub tracing_id: Option, +} + +#[derive(TypedBuilder)] +pub struct DropStoreParams { + #[builder(setter(into, transform = |s: String| StoreName(s)))] + pub store: StoreName, + + #[builder(default = true)] + pub error_if_not_exists: bool, + + pub tracing_id: Option, +} diff --git a/ahnlich/client/src/builders/db.rs b/ahnlich/client/src/builders/db.rs new file mode 100644 index 00000000..e69de29b diff --git a/ahnlich/client/src/builders/mod.rs b/ahnlich/client/src/builders/mod.rs new file mode 100644 index 00000000..d97acb99 --- /dev/null +++ b/ahnlich/client/src/builders/mod.rs @@ -0,0 +1,2 @@ +pub mod ai; +pub mod db; diff --git a/ahnlich/client/src/lib.rs b/ahnlich/client/src/lib.rs index d12c31c0..15c3fef1 100644 --- a/ahnlich/client/src/lib.rs +++ b/ahnlich/client/src/lib.rs @@ -113,6 +113,7 @@ //! let results = pipeline.exec().await.unwrap(); //! ``` pub mod ai; +pub mod builders; pub mod conn; pub mod db; pub mod error;