diff --git a/Cargo.toml b/Cargo.toml index c10986594..2fbc80846 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,9 +22,9 @@ edition = "2021" [workspace.dependencies] # AI -candle = { package = "candle-core", version = "0.7.1" } -candle-nn = { package = "candle-nn", version = "0.7.1" } -candle-transformers = { package = "candle-transformers", version = "0.7.1" } +candle = { package = "candle-core", git = "https://github.com/huggingface/candle.git", branch = "main" } +candle-nn = { package = "candle-nn", git = "https://github.com/huggingface/candle.git", branch = "main" } +candle-transformers = { package = "candle-transformers", git = "https://github.com/huggingface/candle.git", branch = "main" } tokenizers = "0.20.0" hf-hub = "0.3.0" diff --git a/screenpipe-audio/Cargo.toml b/screenpipe-audio/Cargo.toml index 069fed786..44159f7e7 100644 --- a/screenpipe-audio/Cargo.toml +++ b/screenpipe-audio/Cargo.toml @@ -34,8 +34,9 @@ chrono = { version = "0.4.31", features = ["serde"] } candle = { workspace = true } candle-nn = { workspace = true } candle-transformers = { workspace = true } -vad-rs = "0.1.3" tokenizers = { workspace = true } + +vad-rs = "0.1.3" anyhow = "1.0.86" byteorder = "1.5.0" hf-hub = "0.3.2" diff --git a/screenpipe-core/src/candle_utils.rs b/screenpipe-core/src/candle_utils.rs new file mode 100644 index 000000000..f3c0af0e4 --- /dev/null +++ b/screenpipe-core/src/candle_utils.rs @@ -0,0 +1,5 @@ +use candle::Device; + +pub fn get_device() -> Device { + Device::new_metal(0).unwrap_or(Device::new_cuda(0).unwrap_or(Device::Cpu)) +} diff --git a/screenpipe-core/src/lib.rs b/screenpipe-core/src/lib.rs index 374b2be38..71068b4ce 100644 --- a/screenpipe-core/src/lib.rs +++ b/screenpipe-core/src/lib.rs @@ -28,3 +28,6 @@ pub use pipes::*; pub mod pii_removal; #[cfg(feature = "security")] pub use pii_removal::*; + +pub mod candle_utils; +pub use candle_utils::*; diff --git a/screenpipe-vision/Cargo.toml b/screenpipe-vision/Cargo.toml index 56e2a5c2a..aae79daf4 100644 --- a/screenpipe-vision/Cargo.toml +++ b/screenpipe-vision/Cargo.toml @@ -34,10 +34,23 @@ clap = { version = "4.0", features = ["derive"] } # Integrations screenpipe-integrations = { path = "../screenpipe-integrations" } +screenpipe-core = { path = "../screenpipe-core" } tracing-subscriber = { workspace = true } tracing = { workspace = true } +candle = { workspace = true } +candle-nn = { workspace = true } +candle-transformers = { workspace = true } +tokenizers = { workspace = true } +hf-hub = { workspace = true, features = ["tokio"] } + +[features] +metal = ["candle/metal", "candle-nn/metal", "candle-transformers/metal"] +cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"] +mkl = ["candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"] + + [dev-dependencies] tempfile = "3.3.0" criterion = { workspace = true } diff --git a/screenpipe-vision/src/lib.rs b/screenpipe-vision/src/lib.rs index 92166859e..83eb2821a 100644 --- a/screenpipe-vision/src/lib.rs +++ b/screenpipe-vision/src/lib.rs @@ -14,3 +14,4 @@ pub mod capture_screenshot_by_window; #[cfg(target_os = "windows")] pub use microsoft::perform_ocr_windows; pub use tesseract::perform_ocr_tesseract; +pub mod multimodal_embeddings; diff --git a/screenpipe-vision/src/multimodal_embeddings.rs b/screenpipe-vision/src/multimodal_embeddings.rs new file mode 100644 index 000000000..f56623db4 --- /dev/null +++ b/screenpipe-vision/src/multimodal_embeddings.rs @@ -0,0 +1,116 @@ +use std::ops::Mul; + +use anyhow::Result; +use candle::{DType, Device, Tensor}; +use candle_nn::{ops::softmax, VarBuilder}; +use candle_transformers::models::siglip::{Config, Model as SiglipModel}; +use image::DynamicImage; +use tokenizers::Tokenizer; + +pub struct MultimodalEmbedder { + model: SiglipModel, + tokenizer: Tokenizer, + device: Device, + config: Config, +} + +impl MultimodalEmbedder { + pub fn new(device: &Device) -> Result { + let config = Config::base_patch16_224(); + + // Load the model weights from safetensors file + let model_file = { + let api = hf_hub::api::sync::Api::new()?; + let api = api.model("google/siglip-base-patch16-224".to_string()); + api.get("model.safetensors")? + }; + + let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, device)? }; + + let model = SiglipModel::new(&config, vb)?; + let tokenizer = Self::get_tokenizer(None)?; + + Ok(Self { + model, + tokenizer, + device: device.clone(), + config, + }) + } + + fn get_tokenizer(tokenizer_path: Option) -> Result { + let tokenizer_path = match tokenizer_path { + None => { + let api = hf_hub::api::sync::Api::new()?; + let api = api.model("google/siglip-base-patch16-224".to_string()); + api.get("tokenizer.json")? + } + Some(path) => path.into(), + }; + + Tokenizer::from_file(tokenizer_path).map_err(anyhow::Error::msg) + } + + pub fn compute_embeddings( + &self, + image: &DynamicImage, + ocr_text: &str, + ) -> Result<(Tensor, Tensor)> { + let image_tensor = self.preprocess_image(image)?; + let text_tensor = self.tokenize_text(ocr_text)?; + + let (text_embeddings, image_embeddings) = + self.model.forward(&image_tensor, &text_tensor)?; + Ok((text_embeddings, image_embeddings)) + } + + pub fn compute_similarity( + &self, + text_embeddings: &Tensor, + image_embeddings: &Tensor, + ) -> anyhow::Result { + // compute dot product between text and image embeddings + let similarity = text_embeddings.matmul(&image_embeddings.transpose(0, 1)?)?; + + // apply softmax to get probabilities + let similarity = softmax(&similarity, 1)?; + + Ok(similarity) + } + + fn preprocess_image(&self, image: &DynamicImage) -> Result { + let image_size = self.config.vision_config.image_size; + let img = image.resize_to_fill( + image_size as u32, + image_size as u32, + image::imageops::FilterType::Triangle, + ); + let img = img.to_rgb8(); + let img = img.into_raw(); + let img = Tensor::from_vec(img, (image_size, image_size, 3), &self.device)? + .permute((2, 0, 1))? + .to_dtype(DType::F32)? + .affine(2. / 255., -1.)? + .unsqueeze(0)?; + Ok(img) + } + + fn tokenize_text(&self, text: &str) -> anyhow::Result { + let encoding = self + .tokenizer + .encode(text, true) + .map_err(|e| anyhow::anyhow!(e))?; + let mut tokens = encoding.get_ids().to_vec(); + let max_len = self.config.text_config.max_position_embeddings; + let pad_id = self.config.text_config.pad_token_id; + + // Pad the sequence to have the correct length + let len_diff = max_len - tokens.len(); + if len_diff > 0 { + tokens.extend(vec![pad_id; len_diff]); + } + + let input_ids = Tensor::new(vec![tokens], &self.device)?; + Ok(input_ids) + } +} diff --git a/screenpipe-vision/tests/embedding_benchmark.rs b/screenpipe-vision/tests/embedding_benchmark.rs new file mode 100644 index 000000000..6511a941f --- /dev/null +++ b/screenpipe-vision/tests/embedding_benchmark.rs @@ -0,0 +1,45 @@ +use anyhow::Result; +use candle::Device; +use image::DynamicImage; +use screenpipe_core::get_device; +use screenpipe_vision::multimodal_embeddings::MultimodalEmbedder; +use std::time::Instant; + +// Mock function to simulate screenshot capture +fn capture_screenshot() -> Result { + // For this test, we'll create a dummy image + let img = DynamicImage::new_rgb8(224, 224); + Ok(img) +} + +#[test] +fn test_screenshot_and_embedding_speed() -> Result<()> { + let device = get_device(); + let embedder = MultimodalEmbedder::new(&device).unwrap(); + + let start = Instant::now(); + + // Capture screenshot + let screenshot = capture_screenshot()?; + let screenshot_time = start.elapsed(); + + // Perform OCR (mocked for this test) + let ocr_text = "This is a test OCR text"; + + // Compute embeddings + let embedding_start = Instant::now(); + let (text_embeddings, image_embeddings) = embedder.compute_embeddings(&screenshot, ocr_text)?; + let embedding_time = embedding_start.elapsed(); + + // Compute similarity + let similarity = embedder.compute_similarity(&text_embeddings, &image_embeddings)?; + + let total_time = start.elapsed(); + + println!("Screenshot capture time: {:?}", screenshot_time); + println!("Embedding computation time: {:?}", embedding_time); + println!("Total processing time: {:?}", total_time); + println!("Similarity shape: {:?}", similarity.shape()); + + Ok(()) +}