Skip to content

Commit

Permalink
fix ollama-rs update (#219)
Browse files Browse the repository at this point in the history
  • Loading branch information
prabirshrestha authored Sep 8, 2024
1 parent fa268e3 commit e82431d
Showing 1 changed file with 34 additions and 13 deletions.
47 changes: 34 additions & 13 deletions src/embedding/ollama/ollama_embedder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,13 @@ use std::sync::Arc;

use crate::embedding::{embedder_trait::Embedder, EmbedderError};
use async_trait::async_trait;
use ollama_rs::{generation::options::GenerationOptions, Ollama as OllamaClient};
use ollama_rs::{
generation::{
embeddings::request::{EmbeddingsInput, GenerateEmbeddingsRequest},
options::GenerationOptions,
},
Ollama as OllamaClient,
};

#[derive(Debug)]
pub struct OllamaEmbedder {
Expand Down Expand Up @@ -50,29 +56,44 @@ impl Embedder for OllamaEmbedder {
async fn embed_documents(&self, documents: &[String]) -> Result<Vec<Vec<f64>>, EmbedderError> {
log::debug!("Embedding documents: {:?}", documents);

let mut embeddings = Vec::with_capacity(documents.len());

for doc in documents {
let res = self
.client
.generate_embeddings(self.model.clone(), doc.clone(), self.options.clone())
.await?;
let response = self
.client
.generate_embeddings(GenerateEmbeddingsRequest::new(
self.model.clone(),
EmbeddingsInput::Multiple(documents.to_vec()),
))
.await?;

embeddings.push(res.embeddings);
}
let embeddings = response
.embeddings
.into_iter()
.map(|embedding| embedding.into_iter().map(f64::from).collect())
.collect();

Ok(embeddings)
}

async fn embed_query(&self, text: &str) -> Result<Vec<f64>, EmbedderError> {
log::debug!("Embedding query: {:?}", text);

let res = self
let response = self
.client
.generate_embeddings(self.model.clone(), text.to_string(), self.options.clone())
.generate_embeddings(GenerateEmbeddingsRequest::new(
self.model.clone(),
EmbeddingsInput::Single(text.into()),
))
.await?;

Ok(res.embeddings)
let embeddings = response
.embeddings
.into_iter()
.next()
.unwrap()
.into_iter()
.map(f64::from)
.collect();

Ok(embeddings)
}
}

Expand Down

0 comments on commit e82431d

Please sign in to comment.