diff --git a/README.md b/README.md
index 8c4e4d6..790363b 100644
--- a/README.md
+++ b/README.md
@@ -40,7 +40,7 @@ The default model is Flag Embedding, which is top of the [MTEB](https://huggingf
 - [**sentence-transformers/paraphrase-MiniLM-L12-v2**](https://huggingface.co/sentence-transformers/paraphrase-MiniLM-L12-v2)
 - [**sentence-transformers/paraphrase-multilingual-mpnet-base-v2**](https://huggingface.co/sentence-transformers/paraphrase-multilingual-mpnet-base-v2)
 - [**nomic-ai/nomic-embed-text-v1**](https://huggingface.co/nomic-ai/nomic-embed-text-v1)
-- [**nomic-ai/nomic-embed-text-v1.5**](https://huggingface.co/nomic-ai/nomic-embed-text-v1.5)
+- [**nomic-ai/nomic-embed-text-v1.5**](https://huggingface.co/nomic-ai/nomic-embed-text-v1.5) - pairs with the image model nomic-embed-vision-v1.5 for image-to-text search
 - [**intfloat/multilingual-e5-small**](https://huggingface.co/intfloat/multilingual-e5-small)
 - [**intfloat/multilingual-e5-base**](https://huggingface.co/intfloat/multilingual-e5-base)
 - [**intfloat/multilingual-e5-large**](https://huggingface.co/intfloat/multilingual-e5-large)
@@ -59,6 +59,7 @@ The default model is Flag Embedding, which is top of the [MTEB](https://huggingf
 - [**Qdrant/resnet50-onnx**](https://huggingface.co/Qdrant/resnet50-onnx)
 - [**Qdrant/Unicom-ViT-B-16**](https://huggingface.co/Qdrant/Unicom-ViT-B-16)
 - [**Qdrant/Unicom-ViT-B-32**](https://huggingface.co/Qdrant/Unicom-ViT-B-32)
+- [**nomic-ai/nomic-embed-vision-v1.5**](https://huggingface.co/nomic-ai/nomic-embed-vision-v1.5)
 
 ### Reranking
 
diff --git a/src/image_embedding/impl.rs b/src/image_embedding/impl.rs
index 104f85e..fe6f979 100644
--- a/src/image_embedding/impl.rs
+++ b/src/image_embedding/impl.rs
@@ -173,20 +173,46 @@ impl ImageEmbedding {
             let outputs = self.session.run(session_inputs)?;
 
             // Try to get the only output key
-            // If multiple, then default to `image_embeds`
+            // If multiple, then default to a few known keys: `image_embeds` and `last_hidden_state`
             let last_hidden_state_key = match outputs.len() {
-                1 => outputs.keys().next().unwrap(),
-                _ => "image_embeds",
+                1 => vec![outputs.keys().next().unwrap()],
+                _ => vec!["image_embeds", "last_hidden_state"],
             };
 
-            // Extract and normalize embeddings
-            let output_data = outputs[last_hidden_state_key].try_extract_tensor::<f32>()?;
-
-            let embeddings: Vec<Embedding> = output_data
-                .rows()
-                .into_iter()
-                .map(|row| normalize(row.as_slice().unwrap()))
-                .collect();
+            // Extract the tensor and handle different dimensionalities
+            let output_data = last_hidden_state_key
+                .iter()
+                .find_map(|&key| {
+                    outputs
+                        .get(key)
+                        .and_then(|v| v.try_extract_tensor::<f32>().ok())
+                })
+                .ok_or_else(|| anyhow!("Could not extract tensor from any known output key"))?;
+            let shape = output_data.shape();
+
+            let embeddings: Vec<Embedding> = match shape.len() {
+                3 => {
+                    // For 3D output [batch_size, sequence_length, hidden_size],
+                    // take only the first token's (the CLS token's) embedding,
+                    // returning [batch_size, hidden_size]
+                    (0..shape[0])
+                        .map(|batch_idx| {
+                            let cls_embedding =
+                                output_data.slice(ndarray::s![batch_idx, 0, ..]).to_vec();
+                            normalize(&cls_embedding)
+                        })
+                        .collect()
+                }
+                2 => {
+                    // For 2D output [batch_size, hidden_size]
+                    output_data
+                        .rows()
+                        .into_iter()
+                        .map(|row| normalize(row.as_slice().unwrap()))
+                        .collect()
+                }
+                _ => return Err(anyhow!("Unexpected output tensor shape: {:?}", shape)),
+            };
 
             Ok(embeddings)
         })
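In the 3D branch above, only the first (CLS) token of each batch item is kept and then L2-normalized, collapsing `[batch_size, sequence_length, hidden_size]` down to `[batch_size, hidden_size]`. A minimal, self-contained sketch of that slicing with `ndarray`; the `normalize` helper here is a hypothetical stand-in for the crate-internal one:

```rust
use ndarray::{array, s, Array3};

/// L2-normalize a vector (stand-in for fastembed's internal `normalize` helper).
fn normalize(v: &[f32]) -> Vec<f32> {
    let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
    v.iter().map(|x| x / norm).collect()
}

fn main() {
    // Fake [batch_size = 2, sequence_length = 2, hidden_size = 3] model output;
    // the 9.0 rows play the role of non-CLS tokens that get discarded.
    let output: Array3<f32> = array![
        [[3.0, 4.0, 0.0], [9.0, 9.0, 9.0]],
        [[0.0, 0.0, 2.0], [9.0, 9.0, 9.0]],
    ];
    let shape = output.shape();

    // Same idea as the 3D branch: slice out token 0 per batch item, then normalize.
    let embeddings: Vec<Vec<f32>> = (0..shape[0])
        .map(|batch_idx| {
            let cls = output.slice(s![batch_idx, 0, ..]).to_vec();
            normalize(&cls)
        })
        .collect();

    assert_eq!(embeddings[0], vec![0.6, 0.8, 0.0]);
    assert_eq!(embeddings[1], vec![0.0, 0.0, 1.0]);
}
```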
diff --git a/src/models/image_embedding.rs b/src/models/image_embedding.rs
index a13cfdd..90f1410 100644
--- a/src/models/image_embedding.rs
+++ b/src/models/image_embedding.rs
@@ -12,6 +12,8 @@ pub enum ImageEmbeddingModel {
     UnicomVitB16,
     /// Qdrant/Unicom-ViT-B-32
     UnicomVitB32,
+    /// nomic-ai/nomic-embed-vision-v1.5
+    NomicEmbedVisionV15,
 }
 
 pub fn models_list() -> Vec<ModelInfo<ImageEmbeddingModel>> {
@@ -43,7 +45,14 @@ pub fn models_list() -> Vec<ModelInfo<ImageEmbeddingModel>> {
             description: String::from("Unicom Unicom-ViT-B-32 from open-metric-learning"),
             model_code: String::from("Qdrant/Unicom-ViT-B-32"),
             model_file: String::from("model.onnx"),
-        }
+        },
+        ModelInfo {
+            model: ImageEmbeddingModel::NomicEmbedVisionV15,
+            dim: 768,
+            description: String::from("Nomic nomic-embed-vision-v1.5"),
+            model_code: String::from("nomic-ai/nomic-embed-vision-v1.5"),
+            model_file: String::from("onnx/model.onnx"),
+        },
     ];
 
     // TODO: Use when out in stable
diff --git a/tests/embeddings.rs b/tests/embeddings.rs
index 8c2d77d..f8c8093 100644
--- a/tests/embeddings.rs
+++ b/tests/embeddings.rs
@@ -486,6 +486,63 @@ fn test_image_embedding_model() {
     });
 }
 
+#[test]
+#[ignore]
+fn test_nomic_embed_vision_v1_5() {
+    fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
+        let dot_product = a.iter().zip(b).map(|(x, y)| x * y).sum::<f32>();
+        let norm_a = a.iter().map(|x| x * x).sum::<f32>().sqrt();
+        let norm_b = b.iter().map(|x| x * x).sum::<f32>().sqrt();
+        dot_product / (norm_a * norm_b)
+    }
+
+    fn cosine_similarity_matrix(
+        embeddings_a: &[Vec<f32>],
+        embeddings_b: &[Vec<f32>],
+    ) -> Vec<Vec<f32>> {
+        embeddings_a
+            .iter()
+            .map(|a| {
+                embeddings_b
+                    .iter()
+                    .map(|b| cosine_similarity(a, b))
+                    .collect()
+            })
+            .collect()
+    }
+
+    // Test the NomicEmbedVisionV15 model specifically: it outputs a 3D tensor under a
+    // different output key (`last_hidden_state`) than other models. This ensures the
+    // tensor-extraction logic handles both the standard keys and this model's naming.
+    let image_model = ImageEmbedding::try_new(ImageInitOptions::new(
+        fastembed::ImageEmbeddingModel::NomicEmbedVisionV15,
+    ))
+    .unwrap();
+
+    // tests/assets/image_0.png is a blue cat
+    // tests/assets/image_1.png is a red cat
+    let images = vec!["tests/assets/image_0.png", "tests/assets/image_1.png"];
+    let image_embeddings = image_model.embed(images.clone(), None).unwrap();
+    assert_eq!(image_embeddings.len(), images.len());
+
+    let text_model = TextEmbedding::try_new(InitOptions::new(
+        fastembed::EmbeddingModel::NomicEmbedTextV15,
+    ))
+    .unwrap();
+    let texts = vec!["green cat", "blue cat", "red cat", "yellow cat", "dog"];
+    let text_embeddings = text_model.embed(texts.clone(), None).unwrap();
+
+    // Generate the similarity matrix
+    let similarity_matrix = cosine_similarity_matrix(&text_embeddings, &image_embeddings);
+    // Print the similarity matrix with text labels
+    for (i, row) in similarity_matrix.iter().enumerate() {
+        println!("{}: {:?}", texts[i], row);
+    }
+
+    assert_eq!(text_embeddings.len(), texts.len());
+    assert_eq!(text_embeddings[0].len(), 768);
+}
+
 fn clean_cache(model_code: String) {
     let repo = Repo::model(model_code);
     let cache_dir = format!("{}/{}", DEFAULT_CACHE_DIR, repo.folder_name());
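Beyond the ignored test, here is a sketch of what the pairing enables end to end: embedding an image with `NomicEmbedVisionV15` and scoring it against a text query from `NomicEmbedTextV15`, which shares its embedding space. The file name is a placeholder, and `anyhow::Result` is assumed as the crate's public error type; the API calls mirror the test above:

```rust
use fastembed::{
    EmbeddingModel, ImageEmbedding, ImageEmbeddingModel, ImageInitOptions, InitOptions,
    TextEmbedding,
};

// Cosine similarity between two embeddings (same helper as in the test above).
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let dot = a.iter().zip(b).map(|(x, y)| x * y).sum::<f32>();
    let norm_a = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let norm_b = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    dot / (norm_a * norm_b)
}

fn main() -> anyhow::Result<()> {
    // The vision and text towers of the Nomic v1.5 pair embed into a shared space,
    // so a text query can be scored directly against image embeddings.
    let image_model = ImageEmbedding::try_new(ImageInitOptions::new(
        ImageEmbeddingModel::NomicEmbedVisionV15,
    ))?;
    let text_model =
        TextEmbedding::try_new(InitOptions::new(EmbeddingModel::NomicEmbedTextV15))?;

    // Placeholder image path; substitute your own file.
    let image_embeddings = image_model.embed(vec!["cat.png"], None)?;
    let query_embeddings = text_model.embed(vec!["a photo of a cat"], None)?;

    println!(
        "query-image similarity: {:.3}",
        cosine_similarity(&query_embeddings[0], &image_embeddings[0])
    );
    Ok(())
}
```

Since both towers emit 768-dimensional vectors, the score is plain cosine similarity; no projection step is needed.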