From e82431d2ac0019d08b10fdc2e5da656163e1141a Mon Sep 17 00:00:00 2001 From: Prabir Shrestha Date: Sun, 8 Sep 2024 10:26:34 -0700 Subject: [PATCH] fix ollama-rs update (#219) --- src/embedding/ollama/ollama_embedder.rs | 47 ++++++++++++++++++------- 1 file changed, 34 insertions(+), 13 deletions(-) diff --git a/src/embedding/ollama/ollama_embedder.rs b/src/embedding/ollama/ollama_embedder.rs index 2619d4c9..9acdcd39 100644 --- a/src/embedding/ollama/ollama_embedder.rs +++ b/src/embedding/ollama/ollama_embedder.rs @@ -2,7 +2,13 @@ use std::sync::Arc; use crate::embedding::{embedder_trait::Embedder, EmbedderError}; use async_trait::async_trait; -use ollama_rs::{generation::options::GenerationOptions, Ollama as OllamaClient}; +use ollama_rs::{ + generation::{ + embeddings::request::{EmbeddingsInput, GenerateEmbeddingsRequest}, + options::GenerationOptions, + }, + Ollama as OllamaClient, +}; #[derive(Debug)] pub struct OllamaEmbedder { @@ -50,16 +56,19 @@ impl Embedder for OllamaEmbedder { async fn embed_documents(&self, documents: &[String]) -> Result>, EmbedderError> { log::debug!("Embedding documents: {:?}", documents); - let mut embeddings = Vec::with_capacity(documents.len()); - - for doc in documents { - let res = self - .client - .generate_embeddings(self.model.clone(), doc.clone(), self.options.clone()) - .await?; + let response = self + .client + .generate_embeddings(GenerateEmbeddingsRequest::new( + self.model.clone(), + EmbeddingsInput::Multiple(documents.to_vec()), + )) + .await?; - embeddings.push(res.embeddings); - } + let embeddings = response + .embeddings + .into_iter() + .map(|embedding| embedding.into_iter().map(f64::from).collect()) + .collect(); Ok(embeddings) } @@ -67,12 +76,24 @@ impl Embedder for OllamaEmbedder { async fn embed_query(&self, text: &str) -> Result, EmbedderError> { log::debug!("Embedding query: {:?}", text); - let res = self + let response = self .client - .generate_embeddings(self.model.clone(), text.to_string(), self.options.clone()) + .generate_embeddings(GenerateEmbeddingsRequest::new( + self.model.clone(), + EmbeddingsInput::Single(text.into()), + )) .await?; - Ok(res.embeddings) + let embeddings = response + .embeddings + .into_iter() + .next() + .unwrap() + .into_iter() + .map(f64::from) + .collect(); + + Ok(embeddings) } }