# Patch summary (reviewer commentary; everything from the first "diff --git" header onward is the patch itself).
#
# Adds text embeddings and a pgvector-backed vector store:
#   build.gradle.kts, gradle/libs.versions.toml
#       uuid + kotlin-logging for commonMain, HikariCP + PostgreSQL driver for jvmMain,
#       kotest-testcontainers + Testcontainers PostgreSQL for jvmTest.
#   commonMain embeddings/Embeddings.kt        Embedding value and the Embeddings interface.
#   commonMain embeddings/OpenAIEmbeddings.kt  Embeddings implementation: splits texts with chunked(chunkSize),
#       calls OpenAIClient.createEmbeddings per chunk through parMap, retries with RetryConfig.schedule().
#   commonMain env/config.kt                   RetryConfig.schedule() becomes recurs(maxRetries) zipLeft a
#       jittered exponential backoff, so the schedule output is the recurs counter.
#   commonMain llm/openai/models.kt            EmbeddingModel, RequestConfig and its User wrapper.
#   commonMain model.kt                        Document.
#   commonMain vectorstores/VectorStore.kt     VectorStore interface and DocumentVectorId.
#   commonMain vectorstores/postgres.kt        SQL text used by the Postgres store.
#   commonTest embeddings/Mock.kt              Embeddings.mock() test double.
#   jvmMain    JDBCSyntax.kt                   DataSource.connection { } plus bind/cursor helpers over JDBC.
#   jvmMain    PGVectorStore.kt                VectorStore implementation on top of JDBCSyntax.
#   jvmTest    PGVectorStoreSpec.kt            Kotest spec running against an ankane/pgvector container.
#
# Changes made to the previous text of this patch (all in-line; hunk line counts are unchanged):
#   1. One patch line per text line again; the "+1,N" counts of every new file were used to check the split.
#   2. Generic type arguments, which were missing from the text, are written out again. They were
#      reconstructed from how the values are used inside this patch (e.g. Embedding(listOf(1.0f, ...)) => List<Float>).
#      TODO confirm against the repository, in particular the removed lines in env/config.kt and
#      llm/openai/models.kt and the indentation of unchanged context lines.
#   3. EmbeddingModel now stores its argument as `modelName` and OpenAIEmbeddings sends `model.modelName`.
#      Before, the constructor argument was discarded and `model.name` was Enum.name ("TextEmbeddingAda002").
#   4. OpenAIEmbeddings: the retry log lambda named its first parameter (the schedule input, i.e. the failure)
#      `retriesSoFar` and ignored the second (the schedule output, i.e. the retry counter); the two are swapped back.
#   The blob ids on the "index" lines were left as they were.
#
# NOTE(review): PGVectorStore.deleteCollection() calls getCollection(), which throws when the row is missing, so
#   initialDbSetup() with preDeleteCollection = true looks like it fails when the collection does not exist yet - confirm.
# NOTE(review): addTexts and similaritySearch call the Embeddings service inside dataSource.connection { },
#   so a connection is held for the duration of that call.
# NOTE(review): createEmbeddingWithRetry uses kotlin.runCatching in a suspend function; the caught Throwable is
#   rethrown, but the "Giving up" warning is logged for any Throwable - presumably including cancellation; confirm.
# NOTE(review): two spec names mention "CollectionNotFoundError" while the assertions expect IllegalStateException.
# NOTE(review): createCollections, createEmbeddings, getCollectionById and addNewDocument are not referenced by
#   any file in this patch.
diff --git a/build.gradle.kts b/build.gradle.kts
index 49de77cc2..51a00b7bd 100644
--- a/build.gradle.kts
+++ b/build.gradle.kts
@@ -51,6 +51,8 @@ kotlin {
         implementation(libs.kotlinx.serialization.json)
         implementation(libs.bundles.ktor.client)
         implementation(libs.okio)
+        implementation(libs.uuid)
+        implementation(libs.klogging)
       }
     }
 
@@ -63,9 +65,17 @@ kotlin {
         implementation(libs.kotest.assertions.arrow)
       }
     }
+    val jvmMain by getting {
+      dependencies {
+        implementation(libs.hikari)
+        implementation(libs.postgresql)
+      }
+    }
     val jvmTest by getting {
       dependencies {
         implementation(libs.kotest.junit5)
+        implementation(libs.kotest.testcontainers)
+        implementation(libs.testcontainers.postgresql)
       }
     }
   }
diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml
index 60bd1c084..9c0b3e511 100644
--- a/gradle/libs.versions.toml
+++ b/gradle/libs.versions.toml
@@ -7,7 +7,13 @@ ktor = "2.2.2"
 spotless = "6.18.0"
 okio = "3.3.0"
 kotest = "5.5.4"
+kotest-testcontainers = "1.3.4"
 kotest-arrow = "1.3.0"
+klogging = "4.0.0-beta-22"
+uuid = "0.0.18"
+postgresql = "42.5.1"
+testcontainers = "1.17.6"
+hikari = "5.0.1"
 
 [libraries]
 arrow-fx = { module = "io.arrow-kt:arrow-fx-coroutines", version.ref = "arrow" }
@@ -23,7 +29,13 @@ kotest-assertions = { module = "io.kotest:kotest-assertions-core", version.ref =
 kotest-framework = { module = "io.kotest:kotest-framework-engine", version.ref = "kotest" }
 kotest-property = { module = "io.kotest:kotest-property", version.ref = "kotest" }
 kotest-junit5 = { module = "io.kotest:kotest-runner-junit5", version.ref = "kotest" }
+kotest-testcontainers = { module = "io.kotest.extensions:kotest-extensions-testcontainers", version.ref = "kotest-testcontainers" }
 kotest-assertions-arrow = { module = "io.kotest.extensions:kotest-assertions-arrow", version.ref = "kotest-arrow" }
+uuid = { module = "app.softwork:kotlinx-uuid-core", version.ref = "uuid" }
+klogging = { module = "io.github.oshai:kotlin-logging", version.ref = "klogging" }
+hikari = { module = "com.zaxxer:HikariCP", version.ref = "hikari" }
+postgresql = { module = "org.postgresql:postgresql", version.ref = "postgresql" }
+testcontainers-postgresql = { module = "org.testcontainers:postgresql", version.ref = "testcontainers" }
 
 [bundles]
 ktor-client = [
diff --git a/src/commonMain/kotlin/com/xebia/functional/embeddings/Embeddings.kt b/src/commonMain/kotlin/com/xebia/functional/embeddings/Embeddings.kt
new file mode 100644
index 000000000..49c074bb3
--- /dev/null
+++ b/src/commonMain/kotlin/com/xebia/functional/embeddings/Embeddings.kt
@@ -0,0 +1,12 @@
+package com.xebia.functional.embeddings
+
+import com.xebia.functional.llm.openai.RequestConfig
+
+data class Embedding(val data: List<Float>)
+
+interface Embeddings {
+  suspend fun embedDocuments(texts: List<String>, chunkSize: Int?, requestConfig: RequestConfig): List<Embedding>
+  suspend fun embedQuery(text: String, requestConfig: RequestConfig): List<Embedding>
+
+  companion object
+}
diff --git a/src/commonMain/kotlin/com/xebia/functional/embeddings/OpenAIEmbeddings.kt b/src/commonMain/kotlin/com/xebia/functional/embeddings/OpenAIEmbeddings.kt
new file mode 100644
index 000000000..63ababc01
--- /dev/null
+++ b/src/commonMain/kotlin/com/xebia/functional/embeddings/OpenAIEmbeddings.kt
@@ -0,0 +1,51 @@
+package com.xebia.functional.embeddings
+
+import arrow.fx.coroutines.parMap
+import arrow.resilience.retry
+import com.xebia.functional.env.OpenAIConfig
+import com.xebia.functional.llm.openai.EmbeddingRequest
+import com.xebia.functional.llm.openai.OpenAIClient
+import com.xebia.functional.llm.openai.RequestConfig
+import io.github.oshai.KLogger
+import kotlin.time.ExperimentalTime
+
+@ExperimentalTime
+class OpenAIEmbeddings(
+  private val config: OpenAIConfig,
+  private val oaiClient: OpenAIClient,
+  private val logger: KLogger
+) : Embeddings {
+
+  override suspend fun embedDocuments(
+    texts: List<String>,
+    chunkSize: Int?,
+    requestConfig: RequestConfig
+  ): List<Embedding> =
+    chunkedEmbedDocuments(texts, chunkSize ?: config.chunkSize, requestConfig)
+
+  override suspend fun embedQuery(text: String, requestConfig: RequestConfig): List<Embedding> =
+    if (text.isNotEmpty()) embedDocuments(listOf(text), null, requestConfig) else emptyList()
+
+  private suspend fun chunkedEmbedDocuments(
+    texts: List<String>,
+    chunkSize: Int,
+    requestConfig: RequestConfig
+  ): List<Embedding> =
+    if (texts.isEmpty()) emptyList()
+    else texts.chunked(chunkSize)
+      .parMap { createEmbeddingWithRetry(it, requestConfig) }
+      .flatten()
+
+  private suspend fun createEmbeddingWithRetry(texts: List<String>, requestConfig: RequestConfig): List<Embedding> =
+    kotlin.runCatching {
+      config.retryConfig.schedule()
+        .log { _, retriesSoFar -> logger.warn { "Open AI call failed. So far we have retried $retriesSoFar times." } } // (input, output): failure first, retry counter second
+        .retry {
+          oaiClient.createEmbeddings(EmbeddingRequest(requestConfig.model.modelName, texts, requestConfig.user.id)) // modelName, not Enum.name
+            .data.map { Embedding(it.embedding) }
+        }
+    }.getOrElse {
+      logger.warn { "Open AI call failed. Giving up after ${config.retryConfig.maxRetries} retries" }
+      throw it
+    }
+}
\ No newline at end of file
diff --git a/src/commonMain/kotlin/com/xebia/functional/env/config.kt b/src/commonMain/kotlin/com/xebia/functional/env/config.kt
index 9cd8e6897..0f77dd465 100644
--- a/src/commonMain/kotlin/com/xebia/functional/env/config.kt
+++ b/src/commonMain/kotlin/com/xebia/functional/env/config.kt
@@ -17,11 +17,9 @@ data class Env(val openAI: OpenAIConfig, val huggingFace: HuggingFaceConfig)
 data class OpenAIConfig(val token: String, val chunkSize: Int, val retryConfig: RetryConfig)
 
 data class RetryConfig(val backoff: Duration, val maxRetries: Long) {
-  fun schedule(): Schedule<Throwable, Unit> =
+  fun schedule(): Schedule<Throwable, Long> =
     Schedule.recurs<Throwable>(maxRetries)
-      .and(Schedule.exponential(backoff))
-      .jittered(0.75, 1.25)
-      .map { }
+      .zipLeft(Schedule.exponential<Throwable>(backoff).jittered(0.75, 1.25))
 }
 
 data class HuggingFaceConfig(val token: String, val baseUrl: KUrl)
diff --git a/src/commonMain/kotlin/com/xebia/functional/llm/openai/models.kt b/src/commonMain/kotlin/com/xebia/functional/llm/openai/models.kt
index 9e480df88..fc8316380 100644
--- a/src/commonMain/kotlin/com/xebia/functional/llm/openai/models.kt
+++ b/src/commonMain/kotlin/com/xebia/functional/llm/openai/models.kt
@@ -1,8 +1,21 @@
 package com.xebia.functional.llm.openai
 
+import kotlin.jvm.JvmInline
 import kotlinx.serialization.SerialName
 import kotlinx.serialization.Serializable
 
+enum class EmbeddingModel(val modelName: String) { // modelName is the id put in EmbeddingRequest; Enum.name is only the constant's name
+  TextEmbeddingAda002("text-embedding-ada-002")
+}
+
+data class RequestConfig(val model: EmbeddingModel, val user: User) {
+  companion object {
+    @JvmInline
+    value class User(val id: String)
+  }
+}
+
+
 @Serializable
 data class CompletionChoice(val text: String, val index: Int, val finishReason: String)
 
@@ -38,7 +51,7 @@ data class EmbeddingResult(
 )
 
 @Serializable
-class Embedding(val `object`: String, val embedding: List<Double>, val index: Int)
+class Embedding(val `object`: String, val embedding: List<Float>, val index: Int)
 
 @Serializable
 data class Usage(
diff --git a/src/commonMain/kotlin/com/xebia/functional/model.kt b/src/commonMain/kotlin/com/xebia/functional/model.kt
new file mode 100644
index 000000000..3d8164f57
--- /dev/null
+++ b/src/commonMain/kotlin/com/xebia/functional/model.kt
@@ -0,0 +1,3 @@
+package com.xebia.functional
+
+data class Document(val content: String)
diff --git a/src/commonMain/kotlin/com/xebia/functional/vectorstores/VectorStore.kt b/src/commonMain/kotlin/com/xebia/functional/vectorstores/VectorStore.kt
new file mode 100644
index 000000000..3d0e4e29c
--- /dev/null
+++ b/src/commonMain/kotlin/com/xebia/functional/vectorstores/VectorStore.kt
@@ -0,0 +1,45 @@
+package com.xebia.functional.vectorstores
+
+import com.xebia.functional.Document
+import com.xebia.functional.embeddings.Embedding
+import kotlin.jvm.JvmInline
+import kotlinx.uuid.UUID
+
+@JvmInline
+value class DocumentVectorId(val id: UUID)
+
+interface VectorStore {
+  /**
+   * Add texts to the vector store after running them through the embeddings
+   *
+   * @param texts list of text to add to the vector store
+   * @return a list of IDs from adding the texts to the vector store
+   */
+  suspend fun addTexts(texts: List<String>): List<DocumentVectorId>
+
+  /**
+   * Add documents to the vector store after running them through the embeddings
+   *
+   * @param documents list of Documents to add to the vector store
+   * @return a list of IDs from adding the documents to the vector store
+   */
+  suspend fun addDocuments(documents: List<Document>): List<DocumentVectorId>
+
+  /**
+   * Return the docs most similar to the query
+   *
+   * @param query text to use to search for similar documents
+   * @param limit number of documents to return
+   * @return a list of Documents most similar to query
+   */
+  suspend fun similaritySearch(query: String, limit: Int): List<Document>
+
+  /**
+   * Return the docs most similar to the embedding
+   *
+   * @param embedding embedding vector to use to search for similar documents
+   * @param limit number of documents to return
+   * @return list of Documents most similar to the embedding
+   */
+  suspend fun similaritySearchByVector(embedding: Embedding, limit: Int): List<Document>
+}
\ No newline at end of file
diff --git a/src/commonMain/kotlin/com/xebia/functional/vectorstores/postgres.kt b/src/commonMain/kotlin/com/xebia/functional/vectorstores/postgres.kt
new file mode 100644
index 000000000..2b7fcd5f8
--- /dev/null
+++ b/src/commonMain/kotlin/com/xebia/functional/vectorstores/postgres.kt
@@ -0,0 +1,76 @@
+package com.xebia.functional.vectorstores
+
+import kotlinx.uuid.UUID
+
+data class PGCollection(val uuid: UUID, val collectionName: String)
+
+enum class PGDistanceStrategy(val strategy: String) {
+  Euclidean("<->"), InnerProduct("<#>"), CosineDistance("<=>")
+}
+
+val createCollections: String =
+  """CREATE TABLE langchain4k_collections (
+       uuid TEXT PRIMARY KEY,
+       name TEXT UNIQUE NOT NULL
+     );""".trimIndent()
+
+val createEmbeddings: String =
+  """CREATE TABLE langchain4k_embeddings (
+       uuid TEXT PRIMARY KEY,
+       collection_id TEXT REFERENCES langchain4k_collections(uuid),
+       embedding BLOB,
+       content TEXT
+     );""".trimIndent()
+
+val addVectorExtension: String =
+  "CREATE EXTENSION IF NOT EXISTS vector;"
+
+val createCollectionsTable: String =
+  """CREATE TABLE IF NOT EXISTS langchain4k_collections (
+       uuid TEXT PRIMARY KEY,
+       name TEXT UNIQUE NOT NULL
+     );""".trimIndent()
+
+fun createEmbeddingTable(vectorSize: Int): String =
+  """CREATE TABLE IF NOT EXISTS langchain4k_embeddings (
+       uuid TEXT PRIMARY KEY,
+       collection_id TEXT REFERENCES langchain4k_collections(uuid),
+       embedding vector($vectorSize),
+       content TEXT
+     );""".trimIndent()
+
+val addNewCollection: String =
+  """INSERT INTO langchain4k_collections(uuid, name)
+     VALUES (?, ?)
+     ON CONFLICT DO NOTHING;""".trimIndent()
+
+val deleteCollection: String =
+  """DELETE FROM langchain4k_collections
+     WHERE uuid = ?;""".trimIndent()
+
+val getCollection: String =
+  """SELECT * FROM langchain4k_collections
+     WHERE name = ?;""".trimIndent()
+
+val getCollectionById: String =
+  """SELECT * FROM langchain4k_collections
+     WHERE uuid = ?;""".trimIndent()
+
+val addNewDocument: String =
+  """INSERT INTO langchain4k_embeddings(uuid, collection_id, embedding, content)
+     VALUES (?, ?, ?, ?);""".trimIndent()
+
+val deleteCollectionDocs: String =
+  """DELETE FROM langchain4k_embeddings
+     WHERE collection_id = ?;""".trimIndent()
+
+val addNewText: String =
+  """INSERT INTO langchain4k_embeddings(uuid, collection_id, embedding, content)
+     VALUES (?, ?, ?::vector, ?);""".trimIndent()
+
+fun searchSimilarDocument(distance: PGDistanceStrategy): String =
+  """SELECT content FROM langchain4k_embeddings
+     WHERE collection_id = ?
+     ORDER BY embedding
+     ${distance.strategy} ?::vector
+     LIMIT ?;""".trimIndent()
diff --git a/src/commonTest/kotlin/com/xebia/functional/embeddings/Mock.kt b/src/commonTest/kotlin/com/xebia/functional/embeddings/Mock.kt
new file mode 100644
index 000000000..d21954fa9
--- /dev/null
+++ b/src/commonTest/kotlin/com/xebia/functional/embeddings/Mock.kt
@@ -0,0 +1,23 @@
+package com.xebia.functional.embeddings
+
+import com.xebia.functional.llm.openai.RequestConfig
+
+fun Embeddings.Companion.mock(
+  embedDocuments: suspend (texts: List<String>, chunkSize: Int?, config: RequestConfig) -> List<Embedding> = { _, _, _ ->
+    listOf(Embedding(listOf(1.0f, 2.0f, 3.0f)), Embedding(listOf(4.0f, 5.0f, 6.0f)))
+  },
+  embedQuery: suspend (text: String, config: RequestConfig) -> List<Embedding> = { text, _ ->
+    when (text) {
+      "foo" -> listOf(Embedding(listOf(1.0f, 2.0f, 3.0f)))
+      "bar" -> listOf(Embedding(listOf(4.0f, 5.0f, 6.0f)))
+      "baz" -> listOf()
+      else -> listOf()
+    }
+  }
+): Embeddings = object : Embeddings {
+  override suspend fun embedDocuments(texts: List<String>, chunkSize: Int?, requestConfig: RequestConfig): List<Embedding> =
+    embedDocuments(texts, chunkSize, requestConfig)
+
+  override suspend fun embedQuery(text: String, requestConfig: RequestConfig): List<Embedding> =
+    embedQuery(text, requestConfig)
+}
diff --git a/src/jvmMain/kotlin/com/xebia/functional/JDBCSyntax.kt b/src/jvmMain/kotlin/com/xebia/functional/JDBCSyntax.kt
new file mode 100644
index 000000000..ae9aa157d
--- /dev/null
+++ b/src/jvmMain/kotlin/com/xebia/functional/JDBCSyntax.kt
@@ -0,0 +1,107 @@
+package com.xebia.functional
+
+import arrow.core.raise.NullableRaise
+import arrow.core.raise.nullable
+import arrow.fx.coroutines.ResourceScope
+import arrow.fx.coroutines.autoCloseable
+import arrow.fx.coroutines.resourceScope
+import java.sql.Connection
+import java.sql.PreparedStatement
+import java.sql.ResultSet
+import java.sql.Types
+import javax.sql.DataSource
+
+suspend fun <A> DataSource.connection(block: suspend JDBCSyntax.() -> A): A =
+  resourceScope {
+    val conn = autoCloseable { connection }
+    JDBCSyntax(conn, this).block()
+  }
+
+class JDBCSyntax(conn: Connection, resourceScope: ResourceScope) : ResourceScope by resourceScope, Connection by conn {
+
+  suspend fun prepareStatement(
+    sql: String,
+    binders: (SqlPreparedStatement.() -> Unit)? = null
+  ): PreparedStatement = autoCloseable {
+    prepareStatement(sql)
+      .apply { if (binders != null) SqlPreparedStatement(this).binders() }
+  }
+
+  suspend fun update(
+    sql: String,
+    binders: (SqlPreparedStatement.() -> Unit)? = null,
+  ): Unit {
+    val statement = prepareStatement(sql, binders)
+    statement.executeUpdate()
+  }
+
+  suspend fun <A> queryOneOrNull(
+    sql: String,
+    binders: (SqlPreparedStatement.() -> Unit)? = null,
+    mapper: NullableSqlCursor.() -> A
+  ): A? {
+    val statement = prepareStatement(sql, binders)
+    val rs = autoCloseable { statement.executeQuery() }
+    return if (rs.next()) nullable { mapper(NullableSqlCursor(rs, this)) }
+    else null
+  }
+
+  suspend fun <A> queryAsList(
+    sql: String,
+    binders: (SqlPreparedStatement.() -> Unit)? = null,
+    mapper: NullableSqlCursor.() -> A?
+  ): List<A> {
+    val statement = prepareStatement(sql, binders)
+    val rs = autoCloseable { statement.executeQuery() }
+    return buildList {
+      while (rs.next()) {
+        nullable { mapper(NullableSqlCursor(rs, this)) }?.let(::add)
+      }
+    }
+  }
+
+  class SqlPreparedStatement(private val preparedStatement: PreparedStatement) {
+    private var index: Int = 1
+
+    fun bind(short: Short?): Unit = bind(short?.toLong())
+    fun bind(byte: Byte?): Unit = bind(byte?.toLong())
+    fun bind(int: Int?): Unit = bind(int?.toLong())
+    fun bind(char: Char?): Unit = bind(char?.toString())
+
+    fun bind(bytes: ByteArray?): Unit =
+      if (bytes == null) preparedStatement.setNull(index++, Types.BLOB)
+      else preparedStatement.setBytes(index++, bytes)
+
+    fun bind(long: Long?): Unit =
+      if (long == null) preparedStatement.setNull(index++, Types.INTEGER)
+      else preparedStatement.setLong(index++, long)
+
+    fun bind(double: Double?): Unit =
+      if (double == null) preparedStatement.setNull(index++, Types.REAL)
+      else preparedStatement.setDouble(index++, double)
+
+    fun bind(string: String?): Unit =
+      if (string == null) preparedStatement.setNull(index++, Types.VARCHAR)
+      else preparedStatement.setString(index++, string)
+  }
+
+  class SqlCursor(private val resultSet: ResultSet) {
+    private var index: Int = 1
+    fun int(): Int? = long()?.toInt()
+    fun string(): String? = resultSet.getString(index++)
+    fun bytes(): ByteArray? = resultSet.getBytes(index++)
+    fun long(): Long? = resultSet.getLong(index++).takeUnless { resultSet.wasNull() }
+    fun double(): Double? = resultSet.getDouble(index++).takeUnless { resultSet.wasNull() }
+    fun nextRow(): Boolean = resultSet.next()
+  }
+
+  class NullableSqlCursor(private val resultSet: ResultSet, private val raise: NullableRaise) {
+    private var index: Int = 1
+    fun int(): Int = long().toInt()
+    fun string(): String = raise.ensureNotNull(resultSet.getString(index++))
+    fun bytes(): ByteArray = raise.ensureNotNull(resultSet.getBytes(index++))
+    fun long(): Long = raise.ensureNotNull(resultSet.getLong(index++).takeUnless { resultSet.wasNull() })
+    fun double(): Double = raise.ensureNotNull(resultSet.getDouble(index++).takeUnless { resultSet.wasNull() })
+    fun nextRow(): Boolean = resultSet.next()
+  }
+}
diff --git a/src/jvmMain/kotlin/com/xebia/functional/PGVectorStore.kt b/src/jvmMain/kotlin/com/xebia/functional/PGVectorStore.kt
new file mode 100644
index 000000000..679c017e4
--- /dev/null
+++ b/src/jvmMain/kotlin/com/xebia/functional/PGVectorStore.kt
@@ -0,0 +1,101 @@
+package com.xebia.functional
+
+import com.xebia.functional.embeddings.Embedding
+import com.xebia.functional.embeddings.Embeddings
+import com.xebia.functional.llm.openai.RequestConfig
+import com.xebia.functional.vectorstores.DocumentVectorId
+import com.xebia.functional.vectorstores.PGCollection
+import com.xebia.functional.vectorstores.PGDistanceStrategy
+import com.xebia.functional.vectorstores.VectorStore
+import com.xebia.functional.vectorstores.addNewCollection
+import com.xebia.functional.vectorstores.addNewText
+import com.xebia.functional.vectorstores.addVectorExtension
+import com.xebia.functional.vectorstores.createCollectionsTable
+import com.xebia.functional.vectorstores.createEmbeddingTable
+import com.xebia.functional.vectorstores.deleteCollection
+import com.xebia.functional.vectorstores.deleteCollectionDocs
+import com.xebia.functional.vectorstores.getCollection
+import com.xebia.functional.vectorstores.searchSimilarDocument
+import javax.sql.DataSource
+import kotlinx.uuid.UUID
+import kotlinx.uuid.generateUUID
+
+class PGVectorStore(
+  private val vectorSize: Int,
+  private val dataSource: DataSource,
+  private val embeddings: Embeddings,
+  private val collectionName: String,
+  private val distanceStrategy: PGDistanceStrategy,
+  private val preDeleteCollection: Boolean,
+  private val requestConfig: RequestConfig,
+  private val chunckSize: Int?
+) : VectorStore {
+
+  suspend fun JDBCSyntax.getCollection(collectionName: String): PGCollection =
+    queryOneOrNull(getCollection,
+      { bind(collectionName) }
+    ) { PGCollection(UUID(string()), string()) }
+      ?: throw IllegalStateException("Collection '$collectionName' not found")
+
+  suspend fun JDBCSyntax.deleteCollection() {
+    if (preDeleteCollection) {
+      val collection = getCollection(collectionName)
+      update(deleteCollectionDocs) { bind(collection.uuid.toString()) }
+      update(deleteCollection) { bind(collection.uuid.toString()) }
+    }
+  }
+
+  suspend fun initialDbSetup(): Unit = dataSource.connection {
+    update(addVectorExtension)
+    update(createCollectionsTable)
+    update(createEmbeddingTable(vectorSize))
+    deleteCollection()
+  }
+
+  suspend fun createCollection(): Unit = dataSource.connection {
+    val xa = UUID.generateUUID()
+    update(addNewCollection) {
+      bind(xa.toString())
+      bind(collectionName)
+    }
+  }
+
+  override suspend fun addTexts(texts: List<String>): List<DocumentVectorId> = dataSource.connection {
+    val embeddings = embeddings.embedDocuments(texts, chunckSize, requestConfig)
+    val collection = getCollection(collectionName)
+    texts.zip(embeddings) { text, embedding ->
+      val uuid = UUID.generateUUID()
+      update(addNewText) {
+        bind(uuid.toString())
+        bind(collection.uuid.toString())
+        bind(embedding.data.toString())
+        bind(text)
+      }
+      DocumentVectorId(uuid)
+    }
+  }
+
+  override suspend fun addDocuments(documents: List<Document>): List<DocumentVectorId> =
+    addTexts(documents.map(Document::content))
+
+  override suspend fun similaritySearch(query: String, limit: Int): List<Document> = dataSource.connection {
+    val embeddings = embeddings.embedQuery(query, requestConfig)
+      .ifEmpty { throw IllegalStateException("Embedding for text: '$query', has not been properly generated") }
+    val collection = getCollection(collectionName)
+    queryAsList(searchSimilarDocument(distanceStrategy), {
+      bind(collection.uuid.toString())
+      bind(embeddings[0].data.toString())
+      bind(limit)
+    }) { Document(string()) }
+  }
+
+  override suspend fun similaritySearchByVector(embedding: Embedding, limit: Int): List<Document> =
+    dataSource.connection {
+      val collection = getCollection(collectionName)
+      queryAsList(searchSimilarDocument(distanceStrategy), {
+        bind(collection.uuid.toString())
+        bind(embedding.data.toString())
+        bind(limit)
+      }) { Document(string()) }
+    }
+}
diff --git a/src/jvmTest/kotlin/com/xebia/functional/PGVectorStoreSpec.kt b/src/jvmTest/kotlin/com/xebia/functional/PGVectorStoreSpec.kt
new file mode 100644
index 000000000..96d72873c
--- /dev/null
+++ b/src/jvmTest/kotlin/com/xebia/functional/PGVectorStoreSpec.kt
@@ -0,0 +1,91 @@
+package com.xebia.functional
+
+import com.xebia.functional.embeddings.Embedding
+import com.xebia.functional.embeddings.Embeddings
+import com.xebia.functional.embeddings.mock
+import com.xebia.functional.llm.openai.EmbeddingModel
+import com.xebia.functional.llm.openai.RequestConfig
+import com.xebia.functional.vectorstores.PGDistanceStrategy
+import com.zaxxer.hikari.HikariConfig
+import com.zaxxer.hikari.HikariDataSource
+import io.kotest.core.extensions.install
+import io.kotest.core.spec.style.StringSpec
+import io.kotest.extensions.testcontainers.SharedTestContainerExtension
+import io.kotest.matchers.shouldBe
+import org.junit.jupiter.api.assertThrows
+import org.testcontainers.containers.PostgreSQLContainer
+import org.testcontainers.utility.DockerImageName
+
+val postgres: PostgreSQLContainer<Nothing> =
+  PostgreSQLContainer(DockerImageName.parse("ankane/pgvector").asCompatibleSubstituteFor("postgres"))
+
+class PGVectorStoreSpec : StringSpec({
+
+  val container = install(SharedTestContainerExtension(postgres))
+  val dataSource = autoClose(HikariDataSource(HikariConfig().apply {
+    jdbcUrl = container.jdbcUrl
+    username = container.username
+    password = container.password
+    driverClassName = "org.postgresql.Driver"
+  }))
+
+  val pg = PGVectorStore(
+    vectorSize = 3,
+    dataSource = dataSource,
+    embeddings = Embeddings.mock(),
+    collectionName = "test_collection",
+    distanceStrategy = PGDistanceStrategy.Euclidean,
+    preDeleteCollection = false,
+    requestConfig = RequestConfig(EmbeddingModel.TextEmbeddingAda002, RequestConfig.Companion.User("user")),
+    chunckSize = null
+  )
+
+  "initialDbSetup should configure the DB properly" {
+    pg.initialDbSetup()
+  }
+
+  "addTexts should fail with a CollectionNotFoundError if collection isn't present in the DB" {
+    assertThrows<IllegalStateException> {
+      pg.addTexts(listOf("foo", "bar"))
+    }.message shouldBe "Collection 'test_collection' not found"
+  }
+
+  "similaritySearch should fail with a CollectionNotFoundError if collection isn't present in the DB" {
+    assertThrows<IllegalStateException> {
+      pg.similaritySearch("foo", 2)
+    }.message shouldBe "Collection 'test_collection' not found"
+  }
+
+  "createCollection should create collection" {
+    pg.createCollection()
+  }
+
+  "addTexts should return a list of 2 elements" {
+    pg.addTexts(listOf("foo", "bar")).size shouldBe 2
+  }
+
+  "similaritySearchByVector should return both documents" {
+    pg.similaritySearchByVector(Embedding(listOf(4.0f, 5.0f, 6.0f)), 2) shouldBe listOf(
+      Document("bar"),
+      Document("foo")
+    )
+  }
+
+  "addDocuments should return a list of 2 elements" {
+    pg.addDocuments(listOf(Document("foo"), Document("bar"))).size shouldBe 2
+  }
+
+  "similaritySearch should return 2 documents" {
+    pg.similaritySearch("foo", 2).size shouldBe 2
+  }
+
+  "similaritySearch should fail when embedding vector is empty" {
+    assertThrows<IllegalStateException> {
+      pg.similaritySearch("baz", 2)
+    }.message shouldBe "Embedding for text: 'baz', has not been properly generated"
+  }
+
+  "similaritySearchByVector should return document" {
+    pg.similaritySearchByVector(Embedding(listOf(1.0f, 2.0f, 3.0f)), 1) shouldBe listOf(Document("foo"))
+  }
+})