From fd12b91dd782293ae21a665eb6b50ffd6bcdd6c6 Mon Sep 17 00:00:00 2001 From: Ron S <47056605+Intex32@users.noreply.github.com> Date: Wed, 30 Aug 2023 14:53:30 +0200 Subject: [PATCH] gcp embeddings (#358) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * basic implementation of gcp embeddings via rest api (yet to be tested) * complete impl of gcp embeddings; add GCP_TOKEN env var * remove enum EmbeddingModel (no usages and not generic) * apply spotless * remove scala test for removed EmbeddingModel class * Update integrations/gcp/src/commonMain/kotlin/com/xebia/functional/xef/gcp/GcpChat.kt Co-authored-by: Raúl Raja Martínez * return token usage info for gcp embedding --------- Co-authored-by: ron Co-authored-by: Raúl Raja Martínez --- .../com/xebia/functional/xef/AIError.kt | 3 ++ .../llm/models/embeddings/EmbeddingModel.kt | 5 -- .../llm/models/embeddings/RequestConfig.kt | 6 +-- .../functional/xef/store/LocalVectorStore.kt | 4 +- .../functional/xef/conversation/gpc/Chat.kt | 23 ++++++-- .../com/xebia/functional/xef/gcp/GcpChat.kt | 32 +++++++++-- .../com/xebia/functional/xef/gcp/GcpClient.kt | 54 +++++++++++++++++++ .../xebia/functional/xef/gcp/GcpEmbeddings.kt | 27 ++++++++++ .../com/xebia/functional/xef/store/Lucene.kt | 4 +- .../src/test/kotlin/xef/PGVectorStoreSpec.kt | 4 +- .../llm/openai/OpenAIEmbeddings.kt | 2 +- .../models/embeddings/RequestConfigSpec.scala | 11 ---- .../xef/server/services/PostgresXefService.kt | 3 -- 13 files changed, 139 insertions(+), 39 deletions(-) delete mode 100644 core/src/commonMain/kotlin/com/xebia/functional/xef/llm/models/embeddings/EmbeddingModel.kt create mode 100644 integrations/gcp/src/commonMain/kotlin/com/xebia/functional/xef/gcp/GcpEmbeddings.kt delete mode 100644 scala/src/test/scala/com/xebia/functional/xef/llm/models/embeddings/RequestConfigSpec.scala diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/AIError.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/AIError.kt index bcb40e5bf..cc30b4d78 100644 --- a/core/src/commonMain/kotlin/com/xebia/functional/xef/AIError.kt +++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/AIError.kt @@ -21,6 +21,9 @@ sealed class AIError @JvmOverloads constructor(message: String, cause: Throwable data class OpenAI(val errors: NonEmptyList) : Env("OpenAI Environment not found: ${errors.all.joinToString("\n")}") + data class GCP(val errors: NonEmptyList) : + Env("GCP Environment not found: ${errors.all.joinToString("\n")}") + data class HuggingFace(val errors: NonEmptyList) : Env("HuggingFace Environment not found: ${errors.all.joinToString("\n")}") } diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/models/embeddings/EmbeddingModel.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/models/embeddings/EmbeddingModel.kt deleted file mode 100644 index acfc85cc6..000000000 --- a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/models/embeddings/EmbeddingModel.kt +++ /dev/null @@ -1,5 +0,0 @@ -package com.xebia.functional.xef.llm.models.embeddings - -enum class EmbeddingModel(val modelName: String) { - TEXT_EMBEDDING_ADA_002("text-embedding-ada-002") -} diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/models/embeddings/RequestConfig.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/models/embeddings/RequestConfig.kt index 25709a811..4b66b5031 100644 --- a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/models/embeddings/RequestConfig.kt +++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/models/embeddings/RequestConfig.kt @@ -3,11 +3,11 @@ package com.xebia.functional.xef.llm.models.embeddings import kotlin.jvm.JvmInline import kotlin.jvm.JvmStatic -data class RequestConfig(val model: EmbeddingModel, val user: User) { +data class RequestConfig(val user: User) { companion object { @JvmStatic - fun apply(model: EmbeddingModel, userId: String): RequestConfig { - return RequestConfig(model, User(userId)) + fun apply(userId: String): RequestConfig { + return RequestConfig(User(userId)) } @JvmInline value class User(val id: String) diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/store/LocalVectorStore.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/store/LocalVectorStore.kt index d1cc6b6cd..0d66ba8c1 100644 --- a/core/src/commonMain/kotlin/com/xebia/functional/xef/store/LocalVectorStore.kt +++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/store/LocalVectorStore.kt @@ -5,7 +5,6 @@ import arrow.atomic.getAndUpdate import arrow.atomic.update import com.xebia.functional.xef.embeddings.Embedding import com.xebia.functional.xef.embeddings.Embeddings -import com.xebia.functional.xef.llm.models.embeddings.EmbeddingModel import com.xebia.functional.xef.llm.models.embeddings.RequestConfig import kotlin.math.sqrt @@ -26,8 +25,7 @@ private constructor(private val embeddings: Embeddings, private val state: Atomi VectorStore { constructor(embeddings: Embeddings) : this(embeddings, Atomic(State.empty())) - private val requestConfig = - RequestConfig(EmbeddingModel.TEXT_EMBEDDING_ADA_002, RequestConfig.Companion.User("user")) + private val requestConfig = RequestConfig(RequestConfig.Companion.User("user")) override suspend fun addMemories(memories: List) { state.update { prevState -> diff --git a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/conversation/gpc/Chat.kt b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/conversation/gpc/Chat.kt index 284de482d..3c2bca471 100644 --- a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/conversation/gpc/Chat.kt +++ b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/conversation/gpc/Chat.kt @@ -1,15 +1,32 @@ package com.xebia.functional.xef.conversation.gpc +import arrow.core.nonEmptyListOf +import com.xebia.functional.xef.AIError +import com.xebia.functional.xef.conversation.autoClose import com.xebia.functional.xef.conversation.llm.openai.OpenAI +import com.xebia.functional.xef.env.getenv import com.xebia.functional.xef.gcp.GcpChat +import com.xebia.functional.xef.gcp.GcpEmbeddings +import com.xebia.functional.xef.llm.models.embeddings.RequestConfig import com.xebia.functional.xef.prompt.Prompt suspend fun main() { OpenAI.conversation { + val token = + getenv("GCP_TOKEN") ?: throw AIError.Env.GCP(nonEmptyListOf("missing GCP_TOKEN env var")) + val gcp = - autoClose( - GcpChat("us-central1-aiplatform.googleapis.com", "xef-demo", "codechat-bison@001", "token") - ) + GcpChat("us-central1-aiplatform.googleapis.com", "xefdemo", "codechat-bison@001", token) + .let(::autoClose) + val gcpEmbeddingModel = + GcpChat("us-central1-aiplatform.googleapis.com", "xefdemo", "textembedding-gecko", token) + .let(::autoClose) + + val embeddingResult = + GcpEmbeddings(gcpEmbeddingModel) + .embedQuery("strawberry donuts", RequestConfig(RequestConfig.Companion.User("user"))) + println(embeddingResult) + while (true) { print("\n🤖 Enter your question: ") val userInput = readlnOrNull() ?: break diff --git a/integrations/gcp/src/commonMain/kotlin/com/xebia/functional/xef/gcp/GcpChat.kt b/integrations/gcp/src/commonMain/kotlin/com/xebia/functional/xef/gcp/GcpChat.kt index b6dc4c9c2..4e866cdfe 100644 --- a/integrations/gcp/src/commonMain/kotlin/com/xebia/functional/xef/gcp/GcpChat.kt +++ b/integrations/gcp/src/commonMain/kotlin/com/xebia/functional/xef/gcp/GcpChat.kt @@ -4,7 +4,11 @@ import com.xebia.functional.tokenizer.EncodingType import com.xebia.functional.tokenizer.ModelType import com.xebia.functional.xef.llm.Chat import com.xebia.functional.xef.llm.Completion +import com.xebia.functional.xef.llm.Embeddings import com.xebia.functional.xef.llm.models.chat.* +import com.xebia.functional.xef.llm.models.embeddings.Embedding +import com.xebia.functional.xef.llm.models.embeddings.EmbeddingRequest +import com.xebia.functional.xef.llm.models.embeddings.EmbeddingResult import com.xebia.functional.xef.llm.models.text.CompletionChoice import com.xebia.functional.xef.llm.models.text.CompletionRequest import com.xebia.functional.xef.llm.models.text.CompletionResult @@ -17,7 +21,7 @@ import kotlinx.uuid.generateUUID @OptIn(ExperimentalStdlibApi::class) class GcpChat(apiEndpoint: String, projectId: String, modelId: String, token: String) : - Chat, Completion, AutoCloseable { + Chat, Completion, AutoCloseable, Embeddings { private val client: GcpClient = GcpClient(apiEndpoint, projectId, modelId, token) override val name: String = client.modelId @@ -38,7 +42,7 @@ class GcpChat(apiEndpoint: String, projectId: String, modelId: String, token: St getTimeMillis(), client.modelId, listOf(CompletionChoice(response, 0, null, null)), - Usage.ZERO + Usage.ZERO, // TODO: token usage - no information about usage provided by GCP codechat model ) } @@ -58,7 +62,7 @@ class GcpChat(apiEndpoint: String, projectId: String, modelId: String, token: St client.modelId, getTimeMillis().toInt(), client.modelId, - Usage.ZERO, + Usage.ZERO, // TODO: token usage - no information about usage provided by GCP listOf(Choice(Message(Role.ASSISTANT, response, Role.ASSISTANT.name), null, 0)), ) } @@ -83,12 +87,32 @@ class GcpChat(apiEndpoint: String, projectId: String, modelId: String, token: St getTimeMillis().toInt(), client.modelId, listOf(ChatChunk(delta = ChatDelta(Role.ASSISTANT, response))), - Usage.ZERO, + Usage + .ZERO, // TODO: token usage - no information about usage provided by GCP for codechat + // model ) ) } } + override suspend fun createEmbeddings(request: EmbeddingRequest): EmbeddingResult { + fun requestToEmbedding(index: Int, it: GcpClient.EmbeddingPredictions): Embedding = + Embedding("embedding", it.embeddings.values.map(Double::toFloat), index = index) + + val response = client.embeddings(request) + return EmbeddingResult( + data = response.predictions.mapIndexed(::requestToEmbedding), + usage = usage(response), + ) + } + + private fun usage(response: GcpClient.EmbeddingResponse) = + Usage( + totalTokens = response.predictions.sumOf { it.embeddings.statistics.tokenCount }, + promptTokens = null, + completionTokens = null, + ) + override fun tokensFromMessages(messages: List): Int = 0 override fun close() { diff --git a/integrations/gcp/src/commonMain/kotlin/com/xebia/functional/xef/gcp/GcpClient.kt b/integrations/gcp/src/commonMain/kotlin/com/xebia/functional/xef/gcp/GcpClient.kt index 85f1dfb60..410a5a20a 100644 --- a/integrations/gcp/src/commonMain/kotlin/com/xebia/functional/xef/gcp/GcpClient.kt +++ b/integrations/gcp/src/commonMain/kotlin/com/xebia/functional/xef/gcp/GcpClient.kt @@ -3,6 +3,7 @@ package com.xebia.functional.xef.gcp import com.xebia.functional.xef.AIError import com.xebia.functional.xef.conversation.AutoClose import com.xebia.functional.xef.conversation.autoClose +import com.xebia.functional.xef.llm.models.embeddings.EmbeddingRequest import io.ktor.client.HttpClient import io.ktor.client.call.body import io.ktor.client.plugins.* @@ -16,6 +17,7 @@ import io.ktor.http.HttpStatusCode import io.ktor.http.contentType import io.ktor.http.isSuccess import io.ktor.serialization.kotlinx.json.json +import kotlinx.serialization.SerialName import kotlinx.serialization.Serializable import kotlinx.serialization.json.Json @@ -112,6 +114,58 @@ class GcpClient( else throw GcpClientException(response.status, response.bodyAsText()) } + @Serializable + private data class GcpEmbeddingRequest( + val instances: List, + ) + + @Serializable + private data class GcpEmbeddingInstance( + val content: String, + ) + + @Serializable + data class EmbeddingResponse( + val predictions: List, + ) + + @Serializable + data class EmbeddingPredictions( + val embeddings: PredictionEmbeddings, + ) + + @Serializable + data class PredictionEmbeddings( + val statistics: EmbeddingStatistics, + val values: List, + ) + + @Serializable + data class EmbeddingStatistics( + val truncated: Boolean, + @SerialName("token_count") val tokenCount: Int, + ) + + suspend fun embeddings(request: EmbeddingRequest): EmbeddingResponse { + val body = + GcpEmbeddingRequest( + instances = request.input.map(::GcpEmbeddingInstance), + ) + val response = + http.post( + "https://$apiEndpoint/v1/projects/$projectId/locations/us-central1/publishers/google/models/$modelId:predict" + ) { + header("Authorization", "Bearer $token") + contentType(ContentType.Application.Json) + setBody(body) + } + return if (response.status.isSuccess()) { + val embedding = response.body() + if (embedding.predictions.isEmpty()) throw AIError.NoResponse() + embedding + } else throw GcpClientException(response.status, response.bodyAsText()) + } + class GcpClientException(val httpStatusCode: HttpStatusCode, val error: String) : IllegalStateException("$httpStatusCode: $error") diff --git a/integrations/gcp/src/commonMain/kotlin/com/xebia/functional/xef/gcp/GcpEmbeddings.kt b/integrations/gcp/src/commonMain/kotlin/com/xebia/functional/xef/gcp/GcpEmbeddings.kt new file mode 100644 index 000000000..8ffd4a2b8 --- /dev/null +++ b/integrations/gcp/src/commonMain/kotlin/com/xebia/functional/xef/gcp/GcpEmbeddings.kt @@ -0,0 +1,27 @@ +package com.xebia.functional.xef.gcp + +import arrow.fx.coroutines.parMap +import com.xebia.functional.xef.embeddings.Embedding +import com.xebia.functional.xef.embeddings.Embeddings +import com.xebia.functional.xef.llm.models.embeddings.EmbeddingRequest +import com.xebia.functional.xef.llm.models.embeddings.RequestConfig + +class GcpEmbeddings(private val gcpClient: com.xebia.functional.xef.llm.Embeddings) : Embeddings { + override suspend fun embedDocuments( + texts: List, + chunkSize: Int?, + requestConfig: RequestConfig + ): List { + suspend fun createEmbeddings(texts: List): List { + val req = EmbeddingRequest(gcpClient.name, texts, requestConfig.user.id) + return gcpClient.createEmbeddings(req).data.map { Embedding(it.embedding) } + } + val lists: List> = + if (texts.isEmpty()) emptyList() + else texts.chunked(chunkSize ?: 400).parMap { createEmbeddings(it) } + return lists.flatten() + } + + override suspend fun embedQuery(text: String, requestConfig: RequestConfig): List = + if (text.isNotEmpty()) embedDocuments(listOf(text), null, requestConfig) else emptyList() +} diff --git a/integrations/lucene/src/main/kotlin/com/xebia/functional/xef/store/Lucene.kt b/integrations/lucene/src/main/kotlin/com/xebia/functional/xef/store/Lucene.kt index 1bba3de5e..a8f24851d 100644 --- a/integrations/lucene/src/main/kotlin/com/xebia/functional/xef/store/Lucene.kt +++ b/integrations/lucene/src/main/kotlin/com/xebia/functional/xef/store/Lucene.kt @@ -4,7 +4,6 @@ import com.xebia.functional.xef.embeddings.Embedding import com.xebia.functional.xef.embeddings.Embeddings import com.xebia.functional.xef.llm.models.chat.Message import com.xebia.functional.xef.llm.models.chat.Role -import com.xebia.functional.xef.llm.models.embeddings.EmbeddingModel import com.xebia.functional.xef.llm.models.embeddings.RequestConfig import org.apache.lucene.analysis.standard.StandardAnalyzer import org.apache.lucene.document.Document @@ -30,8 +29,7 @@ open class Lucene( private val similarity: VectorSimilarityFunction = VectorSimilarityFunction.EUCLIDEAN ) : VectorStore, AutoCloseable { - private val requestConfig = - RequestConfig(EmbeddingModel.TEXT_EMBEDDING_ADA_002, RequestConfig.Companion.User("user")) + private val requestConfig = RequestConfig(RequestConfig.Companion.User("user")) override suspend fun addMemories(memories: List) { memories.forEach { diff --git a/integrations/postgresql/src/test/kotlin/xef/PGVectorStoreSpec.kt b/integrations/postgresql/src/test/kotlin/xef/PGVectorStoreSpec.kt index 8c37966b5..4df7e40bc 100644 --- a/integrations/postgresql/src/test/kotlin/xef/PGVectorStoreSpec.kt +++ b/integrations/postgresql/src/test/kotlin/xef/PGVectorStoreSpec.kt @@ -4,7 +4,6 @@ import com.xebia.functional.xef.embeddings.Embedding import com.xebia.functional.xef.embeddings.Embeddings import com.xebia.functional.xef.llm.models.chat.Message import com.xebia.functional.xef.llm.models.chat.Role -import com.xebia.functional.xef.llm.models.embeddings.EmbeddingModel import com.xebia.functional.xef.llm.models.embeddings.RequestConfig import com.xebia.functional.xef.store.ConversationId import com.xebia.functional.xef.store.Memory @@ -50,8 +49,7 @@ class PGVectorStoreSpec : collectionName = "test_collection", distanceStrategy = PGDistanceStrategy.Euclidean, preDeleteCollection = false, - requestConfig = - RequestConfig(EmbeddingModel.TEXT_EMBEDDING_ADA_002, RequestConfig.Companion.User("user")), + requestConfig = RequestConfig(RequestConfig.Companion.User("user")), chunkSize = null ) diff --git a/openai/src/commonMain/kotlin/com/xebia/functional/xef/conversation/llm/openai/OpenAIEmbeddings.kt b/openai/src/commonMain/kotlin/com/xebia/functional/xef/conversation/llm/openai/OpenAIEmbeddings.kt index 9c9ac3926..c1646a906 100644 --- a/openai/src/commonMain/kotlin/com/xebia/functional/xef/conversation/llm/openai/OpenAIEmbeddings.kt +++ b/openai/src/commonMain/kotlin/com/xebia/functional/xef/conversation/llm/openai/OpenAIEmbeddings.kt @@ -15,7 +15,7 @@ class OpenAIEmbeddings(private val oaiClient: com.xebia.functional.xef.llm.Embed requestConfig: RequestConfig ): List { suspend fun createEmbeddings(texts: List): List { - val req = EmbeddingRequest(requestConfig.model.modelName, texts, requestConfig.user.id) + val req = EmbeddingRequest(oaiClient.name, texts, requestConfig.user.id) return oaiClient.createEmbeddings(req).data.map { Embedding(it.embedding) } } val lists: List> = diff --git a/scala/src/test/scala/com/xebia/functional/xef/llm/models/embeddings/RequestConfigSpec.scala b/scala/src/test/scala/com/xebia/functional/xef/llm/models/embeddings/RequestConfigSpec.scala deleted file mode 100644 index dffe743fe..000000000 --- a/scala/src/test/scala/com/xebia/functional/xef/llm/models/embeddings/RequestConfigSpec.scala +++ /dev/null @@ -1,11 +0,0 @@ -package com.xebia.functional.xef.llm.models.embeddings - -import munit.FunSuite - -class RequestConfigSpec extends FunSuite: - - test("Should create a RequestConfig through the static apply method") { - val requestConfig = RequestConfig(EmbeddingModel.TEXT_EMBEDDING_ADA_002, "user") - - assertEquals(requestConfig.getModel, EmbeddingModel.TEXT_EMBEDDING_ADA_002) - } diff --git a/server/src/main/kotlin/com/xebia/functional/xef/server/services/PostgresXefService.kt b/server/src/main/kotlin/com/xebia/functional/xef/server/services/PostgresXefService.kt index 2876f0bbc..0060c9b04 100644 --- a/server/src/main/kotlin/com/xebia/functional/xef/server/services/PostgresXefService.kt +++ b/server/src/main/kotlin/com/xebia/functional/xef/server/services/PostgresXefService.kt @@ -3,7 +3,6 @@ package com.xebia.functional.xef.server.services import com.xebia.functional.xef.conversation.autoClose import com.xebia.functional.xef.conversation.llm.openai.OpenAI import com.xebia.functional.xef.conversation.llm.openai.OpenAIEmbeddings -import com.xebia.functional.xef.llm.models.embeddings.EmbeddingModel import com.xebia.functional.xef.llm.models.embeddings.RequestConfig import com.xebia.functional.xef.server.http.routes.Provider import com.xebia.functional.xef.store.PGVectorStore @@ -69,7 +68,6 @@ class PostgresXefService( Provider.OPENAI -> OpenAIEmbeddings(OpenAI(token).DEFAULT_EMBEDDING) else -> OpenAIEmbeddings(OpenAI(token).DEFAULT_EMBEDDING) } - val embeddingModel = EmbeddingModel.TEXT_EMBEDDING_ADA_002 return PGVectorStore( vectorSize = config.vectorSize, @@ -80,7 +78,6 @@ class PostgresXefService( preDeleteCollection = config.preDeleteCollection, requestConfig = RequestConfig( - model = embeddingModel, user = RequestConfig.Companion.User("user") ), chunkSize = config.chunkSize