Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adds support for nomic-embed-vision-v1.5 image embeddings model #129

Merged
merged 2 commits into from
Nov 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ The default model is Flag Embedding, which is top of the [MTEB](https://huggingf
- [**sentence-transformers/paraphrase-MiniLM-L12-v2**](https://huggingface.co/sentence-transformers/paraphrase-MiniLM-L12-v2)
- [**sentence-transformers/paraphrase-multilingual-mpnet-base-v2**](https://huggingface.co/sentence-transformers/paraphrase-multilingual-mpnet-base-v2)
- [**nomic-ai/nomic-embed-text-v1**](https://huggingface.co/nomic-ai/nomic-embed-text-v1)
- [**nomic-ai/nomic-embed-text-v1.5**](https://huggingface.co/nomic-ai/nomic-embed-text-v1.5)
- [**nomic-ai/nomic-embed-text-v1.5**](https://huggingface.co/nomic-ai/nomic-embed-text-v1.5) - pairs with the image model nomic-embed-vision-v1.5 for image-to-text search
- [**intfloat/multilingual-e5-small**](https://huggingface.co/intfloat/multilingual-e5-small)
- [**intfloat/multilingual-e5-base**](https://huggingface.co/intfloat/multilingual-e5-base)
- [**intfloat/multilingual-e5-large**](https://huggingface.co/intfloat/multilingual-e5-large)
Expand All @@ -59,6 +59,7 @@ The default model is Flag Embedding, which is top of the [MTEB](https://huggingf
- [**Qdrant/resnet50-onnx**](https://huggingface.co/Qdrant/resnet50-onnx)
- [**Qdrant/Unicom-ViT-B-16**](https://huggingface.co/Qdrant/Unicom-ViT-B-16)
- [**Qdrant/Unicom-ViT-B-32**](https://huggingface.co/Qdrant/Unicom-ViT-B-32)
- [**nomic-ai/nomic-embed-vision-v1.5**](https://huggingface.co/nomic-ai/nomic-embed-vision-v1.5)

### Reranking

Expand Down
48 changes: 37 additions & 11 deletions src/image_embedding/impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,20 +173,46 @@ impl ImageEmbedding {
let outputs = self.session.run(session_inputs)?;

// Try to get the only output key
// If multiple, then default to `image_embeds`
// If multiple, then default to few known keys `image_embeds` and `last_hidden_state`
let last_hidden_state_key = match outputs.len() {
1 => outputs.keys().next().unwrap(),
_ => "image_embeds",
1 => vec![outputs.keys().next().unwrap()],
_ => vec!["image_embeds", "last_hidden_state"],
};

// Extract and normalize embeddings
let output_data = outputs[last_hidden_state_key].try_extract_tensor::<f32>()?;

let embeddings: Vec<Vec<f32>> = output_data
.rows()
.into_iter()
.map(|row| normalize(row.as_slice().unwrap()))
.collect();
// Extract tensor and handle different dimensionalities
let output_data = last_hidden_state_key
.iter()
.find_map(|&key| {
outputs
.get(key)
.and_then(|v| v.try_extract_tensor::<f32>().ok())
})
.ok_or_else(|| anyhow!("Could not extract tensor from any known output key"))?;
let shape = output_data.shape();

let embeddings: Vec<Vec<f32>> = match shape.len() {
3 => {
// For 3D output [batch_size, sequence_length, hidden_size]
// Take only the first token, sequence_length[0] (CLS token), embedding
// and return [batch_size, hidden_size]
(0..shape[0])
.map(|batch_idx| {
let cls_embedding =
output_data.slice(ndarray::s![batch_idx, 0, ..]).to_vec();
normalize(&cls_embedding)
})
.collect()
}
2 => {
// For 2D output [batch_size, hidden_size]
output_data
.rows()
.into_iter()
.map(|row| normalize(row.as_slice().unwrap()))
.collect()
}
_ => return Err(anyhow!("Unexpected output tensor shape: {:?}", shape)),
};

Ok(embeddings)
})
Expand Down
11 changes: 10 additions & 1 deletion src/models/image_embedding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ pub enum ImageEmbeddingModel {
UnicomVitB16,
/// Qdrant/Unicom-ViT-B-32
UnicomVitB32,
/// nomic-ai/nomic-embed-vision-v1.5
NomicEmbedVisionV15,
}

pub fn models_list() -> Vec<ModelInfo<ImageEmbeddingModel>> {
Expand Down Expand Up @@ -43,7 +45,14 @@ pub fn models_list() -> Vec<ModelInfo<ImageEmbeddingModel>> {
description: String::from("Unicom Unicom-ViT-B-32 from open-metric-learning"),
model_code: String::from("Qdrant/Unicom-ViT-B-32"),
model_file: String::from("model.onnx"),
}
},
ModelInfo {
model: ImageEmbeddingModel::NomicEmbedVisionV15,
dim: 768,
description: String::from("Nomic NomicEmbedVisionV15"),
model_code: String::from("nomic-ai/nomic-embed-vision-v1.5"),
model_file: String::from("onnx/model.onnx"),
},
];

// TODO: Use when out in stable
Expand Down
57 changes: 57 additions & 0 deletions tests/embeddings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,63 @@ fn test_image_embedding_model() {
});
}

#[test]
#[ignore]
Anush008 marked this conversation as resolved.
Show resolved Hide resolved
fn test_nomic_embed_vision_v1_5() {
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
let dot_product = a.iter().zip(b).map(|(x, y)| x * y).sum::<f32>();
let norm_a = a.iter().map(|x| x * x).sum::<f32>().sqrt();
let norm_b = b.iter().map(|x| x * x).sum::<f32>().sqrt();
dot_product / (norm_a * norm_b)
}

fn cosine_similarity_matrix(
embeddings_a: &[Vec<f32>],
embeddings_b: &[Vec<f32>],
) -> Vec<Vec<f32>> {
embeddings_a
.iter()
.map(|a| {
embeddings_b
.iter()
.map(|b| cosine_similarity(a, b))
.collect()
})
.collect()
}

// Test the NomicEmbedVisionV15 model specifically because it outputs a 3D tensor with a different
// output key ('last_hidden_state') compared to other models. This test ensures our tensor extraction
// logic can handle both standard output keys and this model's specific naming convention.
let image_model = ImageEmbedding::try_new(ImageInitOptions::new(
fastembed::ImageEmbeddingModel::NomicEmbedVisionV15,
))
.unwrap();

// tests/assets/image_0.png is a blue cat
// tests/assets/image_1.png is a red cat
let images = vec!["tests/assets/image_0.png", "tests/assets/image_1.png"];
let image_embeddings = image_model.embed(images.clone(), None).unwrap();
assert_eq!(image_embeddings.len(), images.len());

let text_model = TextEmbedding::try_new(InitOptions::new(
fastembed::EmbeddingModel::NomicEmbedTextV15,
))
.unwrap();
let texts = vec!["green cat", "blue cat", "red cat", "yellow cat", "dog"];
let text_embeddings = text_model.embed(texts.clone(), None).unwrap();

// Generate similarity matrix
let similarity_matrix = cosine_similarity_matrix(&text_embeddings, &image_embeddings);
// Print the similarity matrix with text labels
for (i, row) in similarity_matrix.iter().enumerate() {
println!("{}: {:?}", texts[i], row);
}

assert_eq!(text_embeddings.len(), texts.len());
assert_eq!(text_embeddings[0].len(), 768);
}

fn clean_cache(model_code: String) {
let repo = Repo::model(model_code);
let cache_dir = format!("{}/{}", DEFAULT_CACHE_DIR, repo.folder_name());
Expand Down