From eb5bf4f2d4a08c74bf2d9998b286b10000fc7c4e Mon Sep 17 00:00:00 2001 From: Alejandro Serrano Date: Wed, 24 May 2023 14:40:41 +0200 Subject: [PATCH] Make embeddings optional in Lucene (#101) --- .../functional/xef/vectorstores/Lucene.kt | 25 +++++++++++-------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/integrations/lucene/src/main/kotlin/com/xebia/functional/xef/vectorstores/Lucene.kt b/integrations/lucene/src/main/kotlin/com/xebia/functional/xef/vectorstores/Lucene.kt index 4401018e0..52b5ca2bb 100644 --- a/integrations/lucene/src/main/kotlin/com/xebia/functional/xef/vectorstores/Lucene.kt +++ b/integrations/lucene/src/main/kotlin/com/xebia/functional/xef/vectorstores/Lucene.kt @@ -1,6 +1,6 @@ package com.xebia.functional.xef.vectorstores -import arrow.fx.coroutines.Resource +import arrow.fx.coroutines.ResourceScope import arrow.fx.coroutines.autoCloseable import com.xebia.functional.xef.embeddings.Embedding import com.xebia.functional.xef.embeddings.Embeddings @@ -22,13 +22,13 @@ import org.apache.lucene.store.MMapDirectory open class Lucene( private val writer: IndexWriter, private val searcher: IndexSearcher, - private val embeddings: Embeddings, + private val embeddings: Embeddings?, private val similarity: VectorSimilarityFunction = VectorSimilarityFunction.EUCLIDEAN ) : VectorStore, AutoCloseable { constructor( writer: IndexWriter, - embeddings: Embeddings, + embeddings: Embeddings?, similarity: VectorSimilarityFunction = VectorSimilarityFunction.EUCLIDEAN ) : this(writer, IndexSearcher(DirectoryReader.open(writer)), embeddings, similarity) @@ -37,11 +37,11 @@ open class Lucene( override suspend fun addTexts(texts: List) = texts.forEach { - val embedding = embeddings.embedQuery(it, requestConfig) + val embedding = embeddings?.embedQuery(it, requestConfig) val doc = Document().apply { add(TextField("contents", it, Field.Store.YES)) - add(KnnFloatVectorField("embedding", embedding.toFloatArray(), similarity)) + if (embedding != null) add(KnnFloatVectorField("embedding", embedding.toFloatArray(), similarity)) } writer.addDocument(doc) } @@ -49,8 +49,10 @@ open class Lucene( override suspend fun similaritySearch(query: String, limit: Int): List = search(FuzzyQuery(Term("contents", query)), limit) - override suspend fun similaritySearchByVector(embedding: Embedding, limit: Int): List = - search(KnnFloatVectorQuery("embedding", embedding.data.toFloatArray(), limit), limit) + override suspend fun similaritySearchByVector(embedding: Embedding, limit: Int): List { + requireNotNull(embeddings) { "no embeddings were computed for this model" } + return search(KnnFloatVectorQuery("embedding", embedding.data.toFloatArray(), limit), limit) + } private fun search(q: Query, limit: Int): List = searcher.search(q, limit).scoreDocs.map { @@ -65,7 +67,7 @@ open class Lucene( class DirectoryLucene( private val directory: Directory, writerConfig: IndexWriterConfig = IndexWriterConfig(), - embeddings: Embeddings, + embeddings: Embeddings?, similarity: VectorSimilarityFunction = VectorSimilarityFunction.EUCLIDEAN ) : Lucene(IndexWriter(directory, writerConfig), embeddings, similarity) { override fun close() { @@ -77,16 +79,17 @@ class DirectoryLucene( fun InMemoryLucene( path: Path, writerConfig: IndexWriterConfig = IndexWriterConfig(), - embeddings: Embeddings, + embeddings: Embeddings?, similarity: VectorSimilarityFunction = VectorSimilarityFunction.EUCLIDEAN ): DirectoryLucene = DirectoryLucene(MMapDirectory(path), writerConfig, embeddings, similarity) fun InMemoryLuceneBuilder( path: Path, + useAIEmbeddings: Boolean = true, writerConfig: IndexWriterConfig = IndexWriterConfig(), similarity: VectorSimilarityFunction = VectorSimilarityFunction.EUCLIDEAN -): (Embeddings) -> Resource = { embeddings -> - autoCloseable { InMemoryLucene(path, writerConfig, embeddings, similarity) } +): suspend ResourceScope.(Embeddings) -> DirectoryLucene = { embeddings -> + autoCloseable { InMemoryLucene(path, writerConfig, embeddings.takeIf { useAIEmbeddings }, similarity) } } fun List.toFloatArray(): FloatArray = flatMap { it.data }.toFloatArray()