Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

JDBC Postgres VectorStore #11

Merged
merged 9 commits into from
Apr 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ kotlin {
implementation(libs.kotlinx.serialization.json)
implementation(libs.bundles.ktor.client)
implementation(libs.okio)
implementation(libs.uuid)
implementation(libs.klogging)
}
}

Expand All @@ -63,9 +65,17 @@ kotlin {
implementation(libs.kotest.assertions.arrow)
}
}
// JVM main source set: adds the HikariCP and PostgreSQL catalog entries
// (`libs.hikari`, `libs.postgresql`) as implementation dependencies.
val jvmMain by getting {
dependencies {
implementation(libs.hikari)
implementation(libs.postgresql)
}
}
// JVM test source set: Kotest JUnit5 runner plus the Kotest Testcontainers extension
// and the Testcontainers PostgreSQL module from the version catalog.
val jvmTest by getting {
dependencies {
implementation(libs.kotest.junit5)
implementation(libs.kotest.testcontainers)
implementation(libs.testcontainers.postgresql)
}
}
}
Expand Down
12 changes: 12 additions & 0 deletions gradle/libs.versions.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,13 @@ ktor = "2.2.2"
spotless = "6.18.0"
okio = "3.3.0"
kotest = "5.5.4"
kotest-testcontainers = "1.3.4"
kotest-arrow = "1.3.0"
klogging = "4.0.0-beta-22"
uuid = "0.0.18"
postgresql = "42.5.1"
testcontainers = "1.17.6"
hikari = "5.0.1"

[libraries]
arrow-fx = { module = "io.arrow-kt:arrow-fx-coroutines", version.ref = "arrow" }
Expand All @@ -23,7 +29,13 @@ kotest-assertions = { module = "io.kotest:kotest-assertions-core", version.ref =
kotest-framework = { module = "io.kotest:kotest-framework-engine", version.ref = "kotest" }
kotest-property = { module = "io.kotest:kotest-property", version.ref = "kotest" }
kotest-junit5 = { module = "io.kotest:kotest-runner-junit5", version.ref = "kotest" }
kotest-testcontainers = { module = "io.kotest.extensions:kotest-extensions-testcontainers", version.ref = "kotest-testcontainers" }
kotest-assertions-arrow = { module = "io.kotest.extensions:kotest-assertions-arrow", version.ref = "kotest-arrow" }
uuid = { module = "app.softwork:kotlinx-uuid-core", version.ref = "uuid" }
klogging = { module = "io.github.oshai:kotlin-logging", version.ref = "klogging" }
hikari = { module = "com.zaxxer:HikariCP", version.ref = "hikari" }
postgresql = { module = "org.postgresql:postgresql", version.ref = "postgresql" }
testcontainers-postgresql = { module = "org.testcontainers:postgresql", version.ref = "testcontainers" }

[bundles]
ktor-client = [
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package com.xebia.functional.embeddings

import com.xebia.functional.llm.openai.RequestConfig

/** An embedding vector, held as a list of [Float] components in [data]. */
data class Embedding(val data: List<Float>)

/** Abstraction over a service that turns text into [Embedding] vectors. */
interface Embeddings {
/**
 * Computes embeddings for the given [texts].
 *
 * @param texts texts to embed
 * @param chunkSize optional batch size; `null` leaves the choice to the implementation
 * @param requestConfig model and user settings for the request
 * @return the embeddings produced for [texts]
 */
suspend fun embedDocuments(texts: List<String>, chunkSize: Int?, requestConfig: RequestConfig): List<Embedding>
/**
 * Computes embeddings for a single query [text].
 *
 * @param text query text to embed
 * @param requestConfig model and user settings for the request
 * @return the embeddings for [text]; note that the return type is a list, not a single [Embedding]
 */
suspend fun embedQuery(text: String, requestConfig: RequestConfig): List<Embedding>

// Empty companion: gives extension functions a receiver (`Embeddings.Companion`) to attach to.
companion object
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package com.xebia.functional.embeddings

import arrow.fx.coroutines.parMap
import arrow.resilience.retry
import com.xebia.functional.env.OpenAIConfig
import com.xebia.functional.llm.openai.EmbeddingRequest
import com.xebia.functional.llm.openai.OpenAIClient
import com.xebia.functional.llm.openai.RequestConfig
import io.github.oshai.KLogger
import kotlin.time.ExperimentalTime

/**
 * [Embeddings] implementation backed by [OpenAIClient.createEmbeddings].
 *
 * Input texts are split into chunks, each chunk is sent through `parMap`, and every call is
 * retried according to `config.retryConfig.schedule()`.
 *
 * @param config supplies the default chunk size and the retry configuration
 * @param oaiClient client used to request embeddings
 * @param logger receives a warning for every retry and when the call finally fails
 */
@ExperimentalTime
class OpenAIEmbeddings(
    private val config: OpenAIConfig,
    private val oaiClient: OpenAIClient,
    private val logger: KLogger
) : Embeddings {

    /**
     * Embeds [texts] in chunks of [chunkSize], falling back to `config.chunkSize` when it is `null`.
     *
     * @return the embeddings of all chunks, flattened into one list; empty when [texts] is empty
     */
    override suspend fun embedDocuments(
        texts: List<String>,
        chunkSize: Int?,
        requestConfig: RequestConfig
    ): List<Embedding> =
        chunkedEmbedDocuments(texts, chunkSize ?: config.chunkSize, requestConfig)

    /**
     * Embeds a single query [text] by delegating to [embedDocuments].
     *
     * @return the embeddings for [text], or an empty list when [text] is empty (no request is made)
     */
    override suspend fun embedQuery(text: String, requestConfig: RequestConfig): List<Embedding> =
        if (text.isNotEmpty()) embedDocuments(listOf(text), null, requestConfig) else emptyList()

    /** Splits [texts] into chunks of [chunkSize], embeds each chunk via `parMap` and flattens the results. */
    private suspend fun chunkedEmbedDocuments(
        texts: List<String>,
        chunkSize: Int,
        requestConfig: RequestConfig
    ): List<Embedding> =
        if (texts.isEmpty()) emptyList()
        else texts.chunked(chunkSize)
            .parMap { createEmbeddingWithRetry(it, requestConfig) }
            .flatten()

    /**
     * Requests embeddings for one chunk of [texts], retrying per `config.retryConfig.schedule()`.
     *
     * Any failure is rethrown to the caller. A final warning is logged for every failure except
     * coroutine cancellation.
     */
    private suspend fun createEmbeddingWithRetry(texts: List<String>, requestConfig: RequestConfig): List<Embedding> =
        try {
            config.retryConfig.schedule()
                .log { retriesSoFar, _ -> logger.warn { "Open AI call failed. So far we have retried $retriesSoFar times." } }
                .retry {
                    // NOTE(review): `model.name` is `Enum.name` (the constant identifier, e.g.
                    // `TextEmbeddingAda002`), not the string passed to the enum constructor.
                    // Confirm which value the API expects.
                    oaiClient.createEmbeddings(EmbeddingRequest(requestConfig.model.name, texts, requestConfig.user.id))
                        .data.map { Embedding(it.embedding) }
                }
        } catch (cancellation: kotlin.coroutines.cancellation.CancellationException) {
            // Cancellation is not a failed Open AI call: propagate it without the "giving up" warning.
            throw cancellation
        } catch (failure: Throwable) {
            logger.warn { "Open AI call failed. Giving up after ${config.retryConfig.maxRetries} retries" }
            throw failure
        }
}
6 changes: 2 additions & 4 deletions src/commonMain/kotlin/com/xebia/functional/env/config.kt
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,9 @@ data class Env(val openAI: OpenAIConfig, val huggingFace: HuggingFaceConfig)
/**
 * OpenAI settings.
 *
 * @property token API token
 * @property chunkSize default chunk size, used when a caller passes no explicit chunk size
 * @property retryConfig retry settings for OpenAI calls
 */
data class OpenAIConfig(val token: String, val chunkSize: Int, val retryConfig: RetryConfig)

/**
 * Retry settings.
 *
 * @property backoff base duration passed to `Schedule.exponential`
 * @property maxRetries number passed to `Schedule.recurs`
 */
data class RetryConfig(val backoff: Duration, val maxRetries: Long) {
    /**
     * Builds the retry schedule: `Schedule.recurs(maxRetries)` zipped with a jittered
     * (factor 0.75 to 1.25) exponential backoff starting from [backoff].
     *
     * `zipLeft` keeps the output of the left-hand `recurs` schedule, hence the `Long` output type.
     */
    fun schedule(): Schedule<Throwable, Long> =
        Schedule.recurs<Throwable>(maxRetries)
            .zipLeft(Schedule.exponential<Throwable>(backoff).jittered(0.75, 1.25))
}

/** Hugging Face settings: an API [token] and the service [baseUrl]. */
data class HuggingFaceConfig(val token: String, val baseUrl: KUrl)
Expand Down
15 changes: 14 additions & 1 deletion src/commonMain/kotlin/com/xebia/functional/llm/openai/models.kt
Original file line number Diff line number Diff line change
@@ -1,8 +1,21 @@
package com.xebia.functional.llm.openai

import kotlin.jvm.JvmInline
import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable

/**
 * Embedding models that can be requested.
 *
 * The constructor argument is exposed as [modelName]. The inherited `name` property is always
 * the constant's identifier (e.g. `TextEmbeddingAda002`) and cannot be replaced by a
 * constructor parameter, so the string passed here is only reachable through [modelName].
 *
 * @property modelName string passed to the constant's constructor, e.g. `text-embedding-ada-002`
 */
enum class EmbeddingModel(val modelName: String) {
    TextEmbeddingAda002("text-embedding-ada-002")
}

/**
 * Per-request settings for embedding calls.
 *
 * @property model embedding model to use
 * @property user user on whose behalf the request is made
 */
data class RequestConfig(val model: EmbeddingModel, val user: User) {
companion object {
/** Type-safe wrapper around a user id string; declared inside the companion object. */
@JvmInline
value class User(val id: String)
}
}


/** Serializable completion choice: the [text], its [index] and the [finishReason]. */
@Serializable
data class CompletionChoice(val text: String, val index: Int, val finishReason: String)

Expand Down Expand Up @@ -38,7 +51,7 @@ data class EmbeddingResult(
)

/**
 * Serializable embedding entry.
 *
 * @property object value of the `object` field
 * @property embedding vector components, as [Float] values
 * @property index value of the `index` field
 */
@Serializable
class Embedding(val `object`: String, val embedding: List<Float>, val index: Int)

@Serializable
data class Usage(
Expand Down
3 changes: 3 additions & 0 deletions src/commonMain/kotlin/com/xebia/functional/model.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
package com.xebia.functional

/** A document, represented by its text [content]. */
data class Document(val content: String)
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package com.xebia.functional.vectorstores

import com.xebia.functional.Document
import com.xebia.functional.embeddings.Embedding
import kotlin.jvm.JvmInline
import kotlinx.uuid.UUID

/** Type-safe wrapper around the [UUID] returned when texts or documents are added to a [VectorStore]. */
@JvmInline
value class DocumentVectorId(val id: UUID)

/**
 * Storage for texts and documents that supports similarity search, either by query text
 * or by an already computed [Embedding].
 */
interface VectorStore {
/**
 * Add texts to the vector store after running them through the embeddings
 *
 * @param texts list of texts to add to the vector store
 * @return a list of IDs from adding the texts to the vector store
 */
suspend fun addTexts(texts: List<String>): List<DocumentVectorId>

/**
 * Add documents to the vector store after running them through the embeddings
 *
 * @param documents list of Documents to add to the vector store
 * @return a list of IDs from adding the documents to the vector store
 */
suspend fun addDocuments(documents: List<Document>): List<DocumentVectorId>

/**
 * Return the docs most similar to the query
 *
 * @param query text to use to search for similar documents
 * @param limit number of documents to return
 * @return a list of Documents most similar to query
 */
suspend fun similaritySearch(query: String, limit: Int): List<Document>

/**
 * Return the docs most similar to the embedding
 *
 * @param embedding embedding vector to use to search for similar documents
 * @param limit number of documents to return
 * @return list of Documents most similar to the embedding
 */
suspend fun similaritySearchByVector(embedding: Embedding, limit: Int): List<Document>
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package com.xebia.functional.vectorstores

import kotlinx.uuid.UUID

/** A collection: its [uuid] and [collectionName] (presumably a row of `langchain4k_collections` — confirm against the caller). */
data class PGCollection(val uuid: UUID, val collectionName: String)

/**
 * Distance operators usable in a similarity query.
 *
 * @property strategy operator text that [searchSimilarDocument] interpolates into its `ORDER BY` clause
 */
enum class PGDistanceStrategy(val strategy: String) {
Euclidean("<->"), InnerProduct("<#>"), CosineDistance("<=>")
}

// DDL: creates the collections table. Unlike `createCollectionsTable` below there is no
// IF NOT EXISTS clause.
val createCollections: String =
"""CREATE TABLE langchain4k_collections (
uuid TEXT PRIMARY KEY,
name TEXT UNIQUE NOT NULL
);""".trimIndent()

// DDL: creates the embeddings table with the embedding stored in a BLOB column.
// NOTE(review): BLOB is not a PostgreSQL column type (binary data is BYTEA), and
// `createEmbeddingTable` declares the same table with a `vector(n)` column — confirm
// whether this statement is still used.
val createEmbeddings: String =
"""CREATE TABLE langchain4k_embeddings (
uuid TEXT PRIMARY KEY,
collection_id TEXT REFERENCES langchain4k_collections(uuid),
embedding BLOB,
content TEXT
);""".trimIndent()

// Enables the `vector` extension if it is not already installed.
val addVectorExtension: String =
"CREATE EXTENSION IF NOT EXISTS vector;"

// DDL: creates the collections table only if it does not exist yet.
val createCollectionsTable: String =
"""CREATE TABLE IF NOT EXISTS langchain4k_collections (
uuid TEXT PRIMARY KEY,
name TEXT UNIQUE NOT NULL
);""".trimIndent()

/**
 * Builds the DDL that creates the `langchain4k_embeddings` table if it does not exist,
 * with an `embedding` column of type `vector(vectorSize)`.
 *
 * @param vectorSize number of dimensions of the stored vectors; must be positive
 * @return the `CREATE TABLE IF NOT EXISTS` statement
 * @throws IllegalArgumentException if [vectorSize] is zero or negative
 */
fun createEmbeddingTable(vectorSize: Int): String {
    // Fail fast instead of sending DDL with a zero or negative dimension to the database.
    require(vectorSize > 0) { "vectorSize must be positive, but was $vectorSize" }
    return """
        CREATE TABLE IF NOT EXISTS langchain4k_embeddings (
        uuid TEXT PRIMARY KEY,
        collection_id TEXT REFERENCES langchain4k_collections(uuid),
        embedding vector($vectorSize),
        content TEXT
        );
    """.trimIndent()
}

// Inserts a collection (uuid, name); a conflicting row is left untouched (ON CONFLICT DO NOTHING).
val addNewCollection: String =
"""INSERT INTO langchain4k_collections(uuid, name)
VALUES (?, ?)
ON CONFLICT DO NOTHING;""".trimIndent()

// Deletes the collection with the given uuid.
val deleteCollection: String =
"""DELETE FROM langchain4k_collections
WHERE uuid = ?;""".trimIndent()

// Selects the collection with the given name.
val getCollection: String =
"""SELECT * FROM langchain4k_collections
WHERE name = ?;""".trimIndent()

// Selects the collection with the given uuid.
val getCollectionById: String =
"""SELECT * FROM langchain4k_collections
WHERE uuid = ?;""".trimIndent()

// Inserts an embedding row (uuid, collection_id, embedding, content).
// NOTE(review): unlike `addNewText`, the embedding parameter is not cast with `::vector` —
// confirm this is intended for the `vector(n)` column.
val addNewDocument: String =
"""INSERT INTO langchain4k_embeddings(uuid, collection_id, embedding, content)
VALUES (?, ?, ?, ?);""".trimIndent()

// Deletes every embedding row that belongs to the given collection.
val deleteCollectionDocs: String =
"""DELETE FROM langchain4k_embeddings
WHERE collection_id = ?;""".trimIndent()

// Inserts an embedding row, casting the third parameter to `vector`.
val addNewText: String =
"""INSERT INTO langchain4k_embeddings(uuid, collection_id, embedding, content)
VALUES (?, ?, ?::vector, ?);""".trimIndent()

/**
 * Builds the similarity query for one collection.
 *
 * The statement selects `content`, orders rows by the `embedding` column compared to the query
 * vector with the operator of [distance], and takes three positional parameters in this order:
 * collection id, query vector (cast with `::vector`), and row limit.
 *
 * @param distance strategy whose operator text is placed in the `ORDER BY` clause
 * @return the parameterized `SELECT` statement
 */
fun searchSimilarDocument(distance: PGDistanceStrategy): String {
    val operator = distance.strategy
    return """
        SELECT content FROM langchain4k_embeddings
        WHERE collection_id = ?
        ORDER BY embedding
        $operator ?::vector
        LIMIT ?;
    """.trimIndent()
}
23 changes: 23 additions & 0 deletions src/commonTest/kotlin/com/xebia/functional/embeddings/Mock.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package com.xebia.functional.embeddings

import com.xebia.functional.llm.openai.RequestConfig

/**
 * Creates an [Embeddings] test double whose behaviour is supplied by the two lambdas.
 *
 * @param embedDocuments stub for [Embeddings.embedDocuments]; by default it ignores its
 *   arguments and returns two fixed embeddings, `[1, 2, 3]` and `[4, 5, 6]`
 * @param embedQuery stub for [Embeddings.embedQuery]; by default `"foo"` maps to `[1, 2, 3]`,
 *   `"bar"` maps to `[4, 5, 6]`, and any other text (including `"baz"`) yields an empty list
 * @return an [Embeddings] instance that forwards every call to the matching stub
 */
fun Embeddings.Companion.mock(
    embedDocuments: suspend (texts: List<String>, chunkSize: Int?, config: RequestConfig) -> List<Embedding> = { _, _, _ ->
        listOf(Embedding(listOf(1.0f, 2.0f, 3.0f)), Embedding(listOf(4.0f, 5.0f, 6.0f)))
    },
    embedQuery: suspend (text: String, config: RequestConfig) -> List<Embedding> = { text, _ ->
        when (text) {
            "foo" -> listOf(Embedding(listOf(1.0f, 2.0f, 3.0f)))
            "bar" -> listOf(Embedding(listOf(4.0f, 5.0f, 6.0f)))
            else -> emptyList() // "baz" and every other text
        }
    }
): Embeddings {
    // Alias the stubs so the overrides below do not share a name with the lambdas they call.
    val documentsStub = embedDocuments
    val queryStub = embedQuery
    return object : Embeddings {
        override suspend fun embedDocuments(
            texts: List<String>,
            chunkSize: Int?,
            requestConfig: RequestConfig
        ): List<Embedding> = documentsStub(texts, chunkSize, requestConfig)

        override suspend fun embedQuery(text: String, requestConfig: RequestConfig): List<Embedding> =
            queryStub(text, requestConfig)
    }
}
Loading