Skip to content

Commit

Permalink
exp
Browse files Browse the repository at this point in the history
  • Loading branch information
louis030195 committed Sep 29, 2024
1 parent 874716b commit 1a26110
Show file tree
Hide file tree
Showing 8 changed files with 188 additions and 4 deletions.
6 changes: 3 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ edition = "2021"

[workspace.dependencies]
# AI
candle = { package = "candle-core", version = "0.7.1" }
candle-nn = { package = "candle-nn", version = "0.7.1" }
candle-transformers = { package = "candle-transformers", version = "0.7.1" }
candle = { package = "candle-core", git = "https://github.com/huggingface/candle.git", branch = "main" }
candle-nn = { package = "candle-nn", git = "https://github.com/huggingface/candle.git", branch = "main" }
candle-transformers = { package = "candle-transformers", git = "https://github.com/huggingface/candle.git", branch = "main" }
tokenizers = "0.20.0"
hf-hub = "0.3.0"

Expand Down
3 changes: 2 additions & 1 deletion screenpipe-audio/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,9 @@ chrono = { version = "0.4.31", features = ["serde"] }
candle = { workspace = true }
candle-nn = { workspace = true }
candle-transformers = { workspace = true }
vad-rs = "0.1.3"
tokenizers = { workspace = true }

vad-rs = "0.1.3"
anyhow = "1.0.86"
byteorder = "1.5.0"
hf-hub = "0.3.2"
Expand Down
5 changes: 5 additions & 0 deletions screenpipe-core/src/candle_utils.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
use candle::Device;

/// Picks the compute device to run candle models on.
///
/// Tries Metal device 0 first, then CUDA device 0, and finally falls back to
/// the CPU when both accelerator constructors return an error.
///
/// The CUDA attempt is made lazily: it only runs when the Metal constructor
/// failed, so a working Metal setup never triggers CUDA initialisation.
pub fn get_device() -> Device {
    Device::new_metal(0)
        .or_else(|_| Device::new_cuda(0))
        .unwrap_or(Device::Cpu)
}
3 changes: 3 additions & 0 deletions screenpipe-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,6 @@ pub use pipes::*;
pub mod pii_removal;
#[cfg(feature = "security")]
pub use pii_removal::*;

pub mod candle_utils;
pub use candle_utils::*;
13 changes: 13 additions & 0 deletions screenpipe-vision/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,23 @@ clap = { version = "4.0", features = ["derive"] }

# Integrations
screenpipe-integrations = { path = "../screenpipe-integrations" }
screenpipe-core = { path = "../screenpipe-core" }

tracing-subscriber = { workspace = true }
tracing = { workspace = true }

candle = { workspace = true }
candle-nn = { workspace = true }
candle-transformers = { workspace = true }
tokenizers = { workspace = true }
hf-hub = { workspace = true, features = ["tokio"] }

[features]
metal = ["candle/metal", "candle-nn/metal", "candle-transformers/metal"]
cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
mkl = ["candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]


[dev-dependencies]
tempfile = "3.3.0"
criterion = { workspace = true }
Expand Down
1 change: 1 addition & 0 deletions screenpipe-vision/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,4 @@ pub mod capture_screenshot_by_window;
#[cfg(target_os = "windows")]
pub use microsoft::perform_ocr_windows;
pub use tesseract::perform_ocr_tesseract;
pub mod multimodal_embeddings;
116 changes: 116 additions & 0 deletions screenpipe-vision/src/multimodal_embeddings.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
use std::ops::Mul;

use anyhow::Result;
use candle::{DType, Device, Tensor};
use candle_nn::{ops::softmax, VarBuilder};
use candle_transformers::models::siglip::{Config, Model as SiglipModel};
use image::DynamicImage;
use tokenizers::Tokenizer;

/// Bundles a SigLIP model with the tokenizer, device and configuration needed
/// to run it on an image / text pair.
pub struct MultimodalEmbedder {
/// SigLIP model built from `config` and the loaded weights.
model: SiglipModel,
/// Tokenizer used to turn input text into token ids.
tokenizer: Tokenizer,
/// Device on which the input tensors are created.
device: Device,
/// Model configuration; read for the image size, the maximum text length
/// and the padding token id.
config: Config,
}

impl MultimodalEmbedder {
pub fn new(device: &Device) -> Result<Self> {
let config = Config::base_patch16_224();

// Load the model weights from safetensors file
let model_file = {
let api = hf_hub::api::sync::Api::new()?;
let api = api.model("google/siglip-base-patch16-224".to_string());
api.get("model.safetensors")?
};

let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, device)? };

let model = SiglipModel::new(&config, vb)?;
let tokenizer = Self::get_tokenizer(None)?;

Ok(Self {
model,
tokenizer,
device: device.clone(),
config,
})
}

fn get_tokenizer(tokenizer_path: Option<String>) -> Result<Tokenizer> {
let tokenizer_path = match tokenizer_path {
None => {
let api = hf_hub::api::sync::Api::new()?;
let api = api.model("google/siglip-base-patch16-224".to_string());
api.get("tokenizer.json")?
}
Some(path) => path.into(),
};

Tokenizer::from_file(tokenizer_path).map_err(anyhow::Error::msg)
}

pub fn compute_embeddings(
&self,
image: &DynamicImage,
ocr_text: &str,
) -> Result<(Tensor, Tensor)> {
let image_tensor = self.preprocess_image(image)?;
let text_tensor = self.tokenize_text(ocr_text)?;

let (text_embeddings, image_embeddings) =
self.model.forward(&image_tensor, &text_tensor)?;
Ok((text_embeddings, image_embeddings))
}

pub fn compute_similarity(
&self,
text_embeddings: &Tensor,
image_embeddings: &Tensor,
) -> anyhow::Result<Tensor> {
// compute dot product between text and image embeddings
let similarity = text_embeddings.matmul(&image_embeddings.transpose(0, 1)?)?;

// apply softmax to get probabilities
let similarity = softmax(&similarity, 1)?;

Ok(similarity)
}

fn preprocess_image(&self, image: &DynamicImage) -> Result<Tensor> {
let image_size = self.config.vision_config.image_size;
let img = image.resize_to_fill(
image_size as u32,
image_size as u32,
image::imageops::FilterType::Triangle,
);
let img = img.to_rgb8();
let img = img.into_raw();
let img = Tensor::from_vec(img, (image_size, image_size, 3), &self.device)?
.permute((2, 0, 1))?
.to_dtype(DType::F32)?
.affine(2. / 255., -1.)?
.unsqueeze(0)?;
Ok(img)
}

fn tokenize_text(&self, text: &str) -> anyhow::Result<Tensor> {
let encoding = self
.tokenizer
.encode(text, true)
.map_err(|e| anyhow::anyhow!(e))?;
let mut tokens = encoding.get_ids().to_vec();
let max_len = self.config.text_config.max_position_embeddings;
let pad_id = self.config.text_config.pad_token_id;

// Pad the sequence to have the correct length
let len_diff = max_len - tokens.len();
if len_diff > 0 {
tokens.extend(vec![pad_id; len_diff]);
}

let input_ids = Tensor::new(vec![tokens], &self.device)?;
Ok(input_ids)
}
}
45 changes: 45 additions & 0 deletions screenpipe-vision/tests/embedding_benchmark.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
use anyhow::Result;
use candle::Device;
use image::DynamicImage;
use screenpipe_core::get_device;
use screenpipe_vision::multimodal_embeddings::MultimodalEmbedder;
use std::time::Instant;

/// Stand-in for a real screenshot capture: produces a blank 224x224 RGB8
/// image so the test does not depend on a display.
fn capture_screenshot() -> Result<DynamicImage> {
    Ok(DynamicImage::new_rgb8(224, 224))
}

/// Smoke/timing test: builds the embedder, runs it on a dummy screenshot with
/// mocked OCR text, and prints how long each stage took.
///
/// Model construction happens before the timer starts, so it is not included
/// in any of the reported durations.
#[test]
fn test_screenshot_and_embedding_speed() -> Result<()> {
    let device = get_device();
    // Propagate construction failures through the test's `Result` instead of
    // panicking, consistent with the rest of this function.
    let embedder = MultimodalEmbedder::new(&device)?;

    let start = Instant::now();

    // Capture screenshot
    let screenshot = capture_screenshot()?;
    let screenshot_time = start.elapsed();

    // Perform OCR (mocked for this test)
    let ocr_text = "This is a test OCR text";

    // Compute embeddings
    let embedding_start = Instant::now();
    let (text_embeddings, image_embeddings) = embedder.compute_embeddings(&screenshot, ocr_text)?;
    let embedding_time = embedding_start.elapsed();

    // Compute similarity
    let similarity = embedder.compute_similarity(&text_embeddings, &image_embeddings)?;

    let total_time = start.elapsed();

    println!("Screenshot capture time: {:?}", screenshot_time);
    println!("Embedding computation time: {:?}", embedding_time);
    println!("Total processing time: {:?}", total_time);
    println!("Similarity shape: {:?}", similarity.shape());

    Ok(())
}

0 comments on commit 1a26110

Please sign in to comment.