diff --git a/src/embedding/ollama/ollama_embedder.rs b/src/embedding/ollama/ollama_embedder.rs index 2619d4c9..392274e7 100644 --- a/src/embedding/ollama/ollama_embedder.rs +++ b/src/embedding/ollama/ollama_embedder.rs @@ -2,7 +2,14 @@ use std::sync::Arc; use crate::embedding::{embedder_trait::Embedder, EmbedderError}; use async_trait::async_trait; -use ollama_rs::{generation::options::GenerationOptions, Ollama as OllamaClient}; +use ollama_rs::{ + generation::{ + embeddings::request::{EmbeddingsInput, GenerateEmbeddingsRequest}, + options::GenerationOptions, + }, + Ollama as OllamaClient, +}; +use reqwest::Response; #[derive(Debug)] pub struct OllamaEmbedder { @@ -50,16 +57,24 @@ impl Embedder for OllamaEmbedder { async fn embed_documents(&self, documents: &[String]) -> Result>, EmbedderError> { log::debug!("Embedding documents: {:?}", documents); - let mut embeddings = Vec::with_capacity(documents.len()); - - for doc in documents { - let res = self - .client - .generate_embeddings(self.model.clone(), doc.clone(), self.options.clone()) - .await?; + let response = self + .client + .generate_embeddings(GenerateEmbeddingsRequest::new( + self.model.clone(), + EmbeddingsInput::Multiple(documents.iter().map(|doc| doc.clone()).collect()), + )) + .await?; - embeddings.push(res.embeddings); - } + let embeddings = response + .embeddings + .into_iter() + .map(|embedding| { + embedding + .into_iter() + .map(|x| x as f64) + .collect::>() + }) + .collect(); Ok(embeddings) } @@ -67,12 +82,24 @@ impl Embedder for OllamaEmbedder { async fn embed_query(&self, text: &str) -> Result, EmbedderError> { log::debug!("Embedding query: {:?}", text); - let res = self + let response = self .client - .generate_embeddings(self.model.clone(), text.to_string(), self.options.clone()) + .generate_embeddings(GenerateEmbeddingsRequest::new( + self.model.clone(), + EmbeddingsInput::Single(text.into()), + )) .await?; - Ok(res.embeddings) + let embeddings = response + .embeddings + .into_iter() + .next() + .unwrap() + .into_iter() + .map(f64::from) + .collect(); + + Ok(embeddings) } }