advice & tips on indexing using multimodal embeddings 24h of screen recording on consumer hardware (100,000 frames/day) #2557
-
Using CLIP ViT-B/32 and autofaiss, you can do all that in about 10 minutes and 40 MB of storage with a 3080. That said, I think you should maybe consider ColPali: https://blog.vespa.ai/retrieval-with-vision-language-models-colpali/ . It's probably better than CLIP for document retrieval.
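For a rough sense of what the 40 MB figure implies, here is a back-of-the-envelope calculation. It assumes 100,000 frames per day (from the thread title) and 512-dimensional float32 embeddings, which is the output size of CLIP ViT-B/32. The function names are illustrative only, and index overhead is ignored.

```python
def raw_index_bytes(num_vectors: int, dim: int = 512, bytes_per_value: int = 4) -> int:
    """Size of an uncompressed (flat) float32 index, ignoring any overhead."""
    return num_vectors * dim * bytes_per_value


def bytes_per_vector(budget_bytes: int, num_vectors: int) -> float:
    """How many bytes each vector may occupy to fit within a storage budget."""
    return budget_bytes / num_vectors


frames_per_day = 100_000          # figure from the thread title
budget = 40 * 1_000_000           # the 40 MB mentioned above, in bytes

raw = raw_index_bytes(frames_per_day)
print(raw / 1e6)                                  # 204.8 (MB, uncompressed)
print(bytes_per_vector(budget, frames_per_day))   # 400.0 (bytes per vector)
print(raw / budget)                               # 5.12 (compression factor)
```

So a flat index would be around 205 MB per day, and fitting into 40 MB means roughly 5x compression. A quantized index can plausibly reach that, usually at some cost in recall.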
-
Someone else suggested ColPali for its improved document retrieval abilities, and I wanted to build on that for a sec. On the plus side, it would also eliminate most of your need for OCR in the pipeline. But to your point, right now ColPali is far too slow, or at least a naive implementation would be for your use case.

What if you used a very lightweight image embedding model to perform a deduplication step before processing frames? I'd bet most of those frames are duplicates or have substantial informational overlap, so that would reduce the total number of inferences significantly, I'd imagine.

When I ran ColPali locally on CPU with no optimizations (i.e. no MKL, no accelerate), it took about 1 minute per page. I'd assume a fair speedup with those optimizations, but I haven't tried them personally yet.

Also, this model was originally designed for retrieving complex documents. I don't believe the current implementation in the candle repo has any quantization applied. That makes sense for maintaining visual fidelity on complex docs, but for your use case you might be able to get away with a fairly heavily quantized model.

Anywho, just a few thoughts. Interesting problem space and I'm a huge fan of your project. All the best!
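To make the deduplication idea concrete, here is a minimal sketch. It assumes each frame already has a cheap embedding (plain lists of floats here) and drops a frame when it is nearly identical to the last frame that was kept. The function names and the 0.95 threshold are assumptions for illustration, not something taken from ColPali or screenpipe.

```python
import math
from typing import List, Sequence


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity of two vectors; returns 0.0 if either is all zeros."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0
    return dot / (norm_a * norm_b)


def dedupe_consecutive(
    embeddings: Sequence[Sequence[float]], threshold: float = 0.95
) -> List[int]:
    """Return the indices of frames worth sending to the expensive model.

    A frame is skipped when its similarity to the most recently kept
    frame is at or above `threshold`.
    """
    kept: List[int] = []
    for i, emb in enumerate(embeddings):
        if not kept or cosine_similarity(embeddings[kept[-1]], emb) < threshold:
            kept.append(i)
    return kept


# Toy data: frames 1 and 3 are near-copies of the frame before them.
frames = [
    [1.0, 0.0],
    [0.99, 0.01],
    [0.0, 1.0],
    [0.01, 0.99],
    [1.0, 0.0],
]
print(dedupe_consecutive(frames))  # [0, 2, 4]
```

Comparing only against the last kept frame keeps the pass linear in the number of frames, which suits a screen recording where duplicates tend to be adjacent in time.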
-
"""
virtualenv env
source env/bin/activate
pip install torch transformers autofaiss opencv-python-headless pillow fire
"""
import os
import sqlite3
import numpy as np
import faiss
from autofaiss import build_index
import cv2
from transformers import CLIPProcessor, CLIPModel
import torch
from tqdm import tqdm
import fire
import time
from collections import defaultdict
# Set up paths
HOME = os.path.expanduser("~")
DB_PATH = os.path.join(HOME, ".screenpipe", "db.sqlite")
INDEX_PATH = os.path.join(HOME, ".screenpipe", "faiss_index")
BATCH_SIZE = 1000 # Adjust this based on your available memory
# Check if MPS is available
if torch.backends.mps.is_available():
    device = torch.device("mps")
    print("Using MPS (Metal) backend")
else:
    device = torch.device("cpu")
    print("MPS not available, using CPU")
# Initialize CLIP model for text embeddings
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
# Function to extract frame from video
def extract_frame(video_path, frame_number=0):
    cap = cv2.VideoCapture(video_path)
    cap.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
    ret, frame = cap.read()
    cap.release()
    return frame if ret else None


def get_text_embedding(text):
    inputs = clip_processor(text=text, return_tensors="pt", padding=True, truncation=True, max_length=77)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        text_features = clip_model.get_text_features(**inputs)
    return text_features.cpu().numpy().flatten()


def process_batch(batch):
    embeddings = []
    metadata = []
    for row in batch:
        frame_id, timestamp, ocr_text, transcription = row
        combined_text = f"{ocr_text or ''} {transcription or ''}".strip()
        if combined_text:
            embedding = get_text_embedding(combined_text)
            embeddings.append(embedding)
            metadata.append({
                "frame_id": frame_id,
                "timestamp": timestamp,
                "text": combined_text
            })
    return np.array(embeddings).astype('float32'), metadata


def create_index(embeddings, metadata):
    index, _ = build_index(
        embeddings,
        save_on_disk=True,
        index_path=INDEX_PATH,
        index_infos_path=f"{INDEX_PATH}_infos.json",
        metric_type="ip"
    )
    return index
def text_search(query_text, index, metadata, k=5):
    query_embedding = get_text_embedding(query_text)
    D, I = index.search(query_embedding.reshape(1, -1), k)
    results = []
    for score, idx in zip(D[0], I[0]):
        # FAISS pads with -1 when the index holds fewer than k vectors
        if idx < 0:
            continue
        results.append({
            "score": float(score),
            "frame_id": metadata[idx]["frame_id"],
            "timestamp": metadata[idx]["timestamp"],
            "text": metadata[idx]["text"]
        })
    return results
class ScreenpipeSearch:
    def __init__(self):
        self.device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
        print(f"Using device: {self.device}")
        self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(self.device)
        self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
        self.index = None
        self.metadata = None
        self.timings = defaultdict(float)

    def get_text_embedding(self, text):
        inputs = self.clip_processor(text=text, return_tensors="pt", padding=True, truncation=True, max_length=77)
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        with torch.no_grad():
            text_features = self.clip_model.get_text_features(**inputs)
        return text_features.cpu().numpy().flatten()

    def time_function(self, func, *args, **kwargs):
        start_time = time.time()
        result = func(*args, **kwargs)
        end_time = time.time()
        func_name = func.__name__
        self.timings[func_name] += end_time - start_time
        return result

    def process_batch(self, batch):
        return self.time_function(self._process_batch, batch)

    def _process_batch(self, batch):
        embeddings = []
        metadata = []
        for row in batch:
            frame_id, timestamp, ocr_text, transcription = row
            combined_text = f"{ocr_text or ''} {transcription or ''}".strip()
            if combined_text:
                embedding = self.get_text_embedding(combined_text)
                embeddings.append(embedding)
                metadata.append({
                    "frame_id": frame_id,
                    "timestamp": timestamp,
                    "text": combined_text
                })
        return np.array(embeddings).astype('float32'), metadata
    def build(self):
        start_time = time.time()
        conn = sqlite3.connect(DB_PATH)
        cursor = conn.cursor()
        all_embeddings = []
        all_metadata = []
        cursor.execute("""
            SELECT f.id, f.timestamp, o.text AS ocr_text, a.transcription
            FROM frames f
            LEFT JOIN ocr_text o ON f.id = o.frame_id
            LEFT JOIN audio_transcriptions a ON f.timestamp = a.timestamp
            ORDER BY f.timestamp
        """)
        all_data = cursor.fetchall()
        conn.close()
        self.timings['database_query'] = time.time() - start_time
        # The row count comes from the fetched data, so no separate COUNT(*) query is needed
        total_rows = len(all_data)
        with tqdm(total=total_rows, desc="Processing data") as pbar:
            for i in range(0, total_rows, BATCH_SIZE):
                batch = all_data[i:i+BATCH_SIZE]
                embeddings, metadata = self.process_batch(batch)
                # A batch with no text yields an empty array that cannot be concatenated
                if metadata:
                    all_embeddings.append(embeddings)
                    all_metadata.extend(metadata)
                pbar.update(len(batch))
        self.timings['data_processing'] = time.time() - start_time - self.timings['database_query']
        if not all_embeddings:
            raise ValueError("No OCR text or transcriptions found to index")
        final_embeddings = np.concatenate(all_embeddings)
        index_start_time = time.time()
        self.index = faiss.IndexFlatIP(final_embeddings.shape[1])
        self.index.add(final_embeddings)
        # Keep the metadata in memory so search() works right after build()
        self.metadata = all_metadata
        self.timings['indexing'] = time.time() - index_start_time
        save_start_time = time.time()
        faiss.write_index(self.index, INDEX_PATH)
        np.save(f"{INDEX_PATH}_metadata.npy", all_metadata)
        self.timings['saving'] = time.time() - save_start_time
        total_time = time.time() - start_time
        self.timings['total'] = total_time
        print(f"Index built and saved to {INDEX_PATH}")
        self.print_timings()

    def print_timings(self):
        print("\nTiming breakdown:")
        for key, value in self.timings.items():
            print(f"{key}: {value:.2f} seconds")
    def estimate_total_time(self, sample_size=1000):
        conn = sqlite3.connect(DB_PATH)
        cursor = conn.cursor()
        cursor.execute("SELECT COUNT(*) FROM frames")
        total_rows = cursor.fetchone()[0]
        cursor.execute(f"""
            SELECT f.id, f.timestamp, o.text AS ocr_text, a.transcription
            FROM frames f
            LEFT JOIN ocr_text o ON f.id = o.frame_id
            LEFT JOIN audio_transcriptions a ON f.timestamp = a.timestamp
            ORDER BY RANDOM()
            LIMIT {sample_size}
        """)
        sample_data = cursor.fetchall()
        conn.close()
        if not sample_data:
            print("No rows available to sample")
            return 0.0
        start_time = time.time()
        self.process_batch(sample_data)
        sample_time = time.time() - start_time
        # Divide by the number of rows actually sampled, which may be below sample_size
        estimated_time = (sample_time / len(sample_data)) * total_rows
        print(f"\nEstimated total processing time: {estimated_time:.2f} seconds")
        print(f"Estimated total processing time: {estimated_time/60:.2f} minutes")
        print(f"Estimated total processing time: {estimated_time/3600:.2f} hours")
        return estimated_time

    def load(self):
        self.index = faiss.read_index(INDEX_PATH)
        with open(f"{INDEX_PATH}_metadata.npy", 'rb') as f:
            # The metadata was saved as an object array of dicts, so convert it back to a list
            self.metadata = np.load(f, allow_pickle=True).tolist()
        print(f"Index loaded from {INDEX_PATH}")
    def search(self, query_text, k=5):
        if self.index is None or self.metadata is None:
            self.load()
        query_embedding = self.get_text_embedding(query_text)
        D, I = self.index.search(query_embedding.reshape(1, -1), k)
        results = [
            {
                "score": float(score),
                "frame_id": self.metadata[idx]["frame_id"],
                "timestamp": self.metadata[idx]["timestamp"],
                "text": self.metadata[idx]["text"]
            }
            for score, idx in zip(D[0], I[0])
            if idx >= 0  # FAISS pads with -1 when fewer than k vectors exist
        ]
        return results


def build():
    searcher = ScreenpipeSearch()
    searcher.estimate_total_time()
    searcher.build()


def search(query, k=5):
    searcher = ScreenpipeSearch()
    results = searcher.search(query, k)
    for result in results:
        print(f"Score: {result['score']}")
        print(f"Frame ID: {result['frame_id']}")
        print(f"Timestamp: {result['timestamp']}")
        print(f"Text: {result['text']}")
        print("---")


if __name__ == "__main__":
    fire.Fire({
        "build": build,
        "search": search
    })

Trying this for fun: it takes 3h on my Mac 🙁
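One thing that may contribute to the runtime is that the script calls the model once per row. A common alternative is to pass several texts to the model in a single call. The sketch below shows only the chunking logic: `embed_fn` is a hypothetical callable standing in for a batched CLIP call, and `batch_size=64` is an arbitrary value, not a tuned one.

```python
from typing import Callable, Iterator, List, Sequence

Vector = List[float]


def chunked(items: Sequence[str], batch_size: int) -> Iterator[Sequence[str]]:
    """Yield consecutive slices of `items` with at most `batch_size` elements."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]


def embed_in_batches(
    texts: Sequence[str],
    embed_fn: Callable[[Sequence[str]], List[Vector]],
    batch_size: int = 64,
) -> List[Vector]:
    """Call `embed_fn` once per chunk instead of once per text."""
    vectors: List[Vector] = []
    for batch in chunked(texts, batch_size):
        vectors.extend(embed_fn(batch))
    return vectors


# Stand-in embedder for demonstration: one number per text (its length).
calls: List[int] = []


def fake_embed(batch: Sequence[str]) -> List[Vector]:
    calls.append(len(batch))  # record how many texts each call received
    return [[float(len(text))] for text in batch]


result = embed_in_batches(["a", "bb", "ccc", "dddd", "eeeee"], fake_embed, batch_size=2)
print(result)  # [[1.0], [2.0], [3.0], [4.0], [5.0]]
print(calls)   # [2, 2, 1]
```

With a real model, `embed_fn` would tokenize the whole chunk at once and return one vector per text; whether that helps on a given Mac would need to be measured.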
-
hi, i'm working on:
https://github.com/mediar-ai/screenpipe
it records your screens & mics 24/7 and extracts OCR & STT into a local SQLite db
we want to explore vector search to improve search relevancy. the problem is that it takes a lot of resources & time, so we're considering running the computation at night or when the computer is less busy
mediar-ai/screenpipe#377
quick maths
i'm trying to do some quick maths to see whether this is actually possible or just a fantasy at this point:
To estimate the time it would take to embed 24 hours of screen, OCR, and STT recording using batch processing, we need to make some assumptions and do some calculations. Let's break this down:
Assumptions:
Calculations:
Total frames in 24 hours:
24 hours * 60 minutes * 60 seconds = 86,400 seconds, i.e. 86,400 frames (this implies 1 frame per second)
Number of batches:
86,400 frames / 60 frames per batch = 1,440 batches
Time per batch:
If a single inference takes 10 seconds, let's assume batch processing might be more efficient.
Let's estimate 15 seconds per batch of 60 frames.
Total processing time:
1,440 batches * 15 seconds = 21,600 seconds = 6 hours
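The same arithmetic as a small script, so the inputs are easy to change. All defaults are the guesses used above (1 frame per second, 60 frames per batch, 15 seconds per batch), not measurements, and the function name is illustrative.

```python
import math

SECONDS_PER_DAY = 24 * 60 * 60  # 86,400


def estimate_processing_seconds(
    frames_per_second: float = 1.0,
    frames_per_batch: int = 60,
    seconds_per_batch: float = 15.0,
) -> float:
    """Estimated time to process one day of recording, in seconds."""
    total_frames = SECONDS_PER_DAY * frames_per_second
    batches = math.ceil(total_frames / frames_per_batch)
    return batches * seconds_per_batch


batched = estimate_processing_seconds()
print(batched, batched / 3600)      # 21600.0 6.0

# Without batching, at the assumed 10 seconds per single inference:
unbatched = estimate_processing_seconds(frames_per_batch=1, seconds_per_batch=10.0)
print(unbatched, unbatched / 3600)  # 864000.0 240.0
```

The 6 hour figure therefore depends entirely on the guess that a batch of 60 frames costs only 15 seconds; at 10 seconds per frame with no batching, a day of recording would take about 240 hours.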
Considerations:
Based on these calculations, a very rough estimate for processing 24 hours of data might be around 6 hours. However, this could be significantly reduced with optimizations:
To improve performance, you could:
Remember, the actual implementation and optimization could significantly change these estimates. It would be best to prototype the batch processing approach and measure actual performance on your target hardware.
notes
questions
thank you!