Skip to content

Commit

Permalink
gcp embeddings (#358)
Browse files Browse the repository at this point in the history
* basic implementation of gcp embeddings via rest api (yet to be tested)

* complete impl of gcp embeddings; add GCP_TOKEN env var

* remove enum EmbeddingModel (no usages and not generic)

* apply spotless

* remove scala test for removed EmbeddingModel class

* Update integrations/gcp/src/commonMain/kotlin/com/xebia/functional/xef/gcp/GcpChat.kt

Co-authored-by: Raúl Raja Martínez <raulraja@gmail.com>

* return token usage info for gcp embedding

---------

Co-authored-by: ron <ron.spannagel@47deg.com>
Co-authored-by: Raúl Raja Martínez <raulraja@gmail.com>
  • Loading branch information
3 people authored Aug 30, 2023
1 parent 547a1b2 commit fd12b91
Show file tree
Hide file tree
Showing 13 changed files with 139 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ sealed class AIError @JvmOverloads constructor(message: String, cause: Throwable
/** Raised when the required OpenAI environment variables are missing or invalid. */
data class OpenAI(val errors: NonEmptyList<String>) :
Env("OpenAI Environment not found: ${errors.all.joinToString("\n")}")

/** Raised when the required GCP environment variables (e.g. GCP_TOKEN) are missing or invalid. */
data class GCP(val errors: NonEmptyList<String>) :
Env("GCP Environment not found: ${errors.all.joinToString("\n")}")

/** Raised when the required HuggingFace environment variables are missing or invalid. */
data class HuggingFace(val errors: NonEmptyList<String>) :
Env("HuggingFace Environment not found: ${errors.all.joinToString("\n")}")
}
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@ package com.xebia.functional.xef.llm.models.embeddings
import kotlin.jvm.JvmInline
import kotlin.jvm.JvmStatic

data class RequestConfig(val model: EmbeddingModel, val user: User) {
data class RequestConfig(val user: User) {
companion object {
@JvmStatic
fun apply(model: EmbeddingModel, userId: String): RequestConfig {
return RequestConfig(model, User(userId))
fun apply(userId: String): RequestConfig {
return RequestConfig(User(userId))
}

@JvmInline value class User(val id: String)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import arrow.atomic.getAndUpdate
import arrow.atomic.update
import com.xebia.functional.xef.embeddings.Embedding
import com.xebia.functional.xef.embeddings.Embeddings
import com.xebia.functional.xef.llm.models.embeddings.EmbeddingModel
import com.xebia.functional.xef.llm.models.embeddings.RequestConfig
import kotlin.math.sqrt

Expand All @@ -26,8 +25,7 @@ private constructor(private val embeddings: Embeddings, private val state: Atomi
VectorStore {
constructor(embeddings: Embeddings) : this(embeddings, Atomic(State.empty()))

private val requestConfig =
RequestConfig(EmbeddingModel.TEXT_EMBEDDING_ADA_002, RequestConfig.Companion.User("user"))
private val requestConfig = RequestConfig(RequestConfig.Companion.User("user"))

override suspend fun addMemories(memories: List<Memory>) {
state.update { prevState ->
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,32 @@
package com.xebia.functional.xef.conversation.gpc

import arrow.core.nonEmptyListOf
import com.xebia.functional.xef.AIError
import com.xebia.functional.xef.conversation.autoClose
import com.xebia.functional.xef.conversation.llm.openai.OpenAI
import com.xebia.functional.xef.env.getenv
import com.xebia.functional.xef.gcp.GcpChat
import com.xebia.functional.xef.gcp.GcpEmbeddings
import com.xebia.functional.xef.llm.models.embeddings.RequestConfig
import com.xebia.functional.xef.prompt.Prompt

suspend fun main() {
OpenAI.conversation {
val token =
getenv("GCP_TOKEN") ?: throw AIError.Env.GCP(nonEmptyListOf("missing GCP_TOKEN env var"))

val gcp =
autoClose(
GcpChat("us-central1-aiplatform.googleapis.com", "xef-demo", "codechat-bison@001", "token")
)
GcpChat("us-central1-aiplatform.googleapis.com", "xefdemo", "codechat-bison@001", token)
.let(::autoClose)
val gcpEmbeddingModel =
GcpChat("us-central1-aiplatform.googleapis.com", "xefdemo", "textembedding-gecko", token)
.let(::autoClose)

val embeddingResult =
GcpEmbeddings(gcpEmbeddingModel)
.embedQuery("strawberry donuts", RequestConfig(RequestConfig.Companion.User("user")))
println(embeddingResult)

while (true) {
print("\n🤖 Enter your question: ")
val userInput = readlnOrNull() ?: break
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@ import com.xebia.functional.tokenizer.EncodingType
import com.xebia.functional.tokenizer.ModelType
import com.xebia.functional.xef.llm.Chat
import com.xebia.functional.xef.llm.Completion
import com.xebia.functional.xef.llm.Embeddings
import com.xebia.functional.xef.llm.models.chat.*
import com.xebia.functional.xef.llm.models.embeddings.Embedding
import com.xebia.functional.xef.llm.models.embeddings.EmbeddingRequest
import com.xebia.functional.xef.llm.models.embeddings.EmbeddingResult
import com.xebia.functional.xef.llm.models.text.CompletionChoice
import com.xebia.functional.xef.llm.models.text.CompletionRequest
import com.xebia.functional.xef.llm.models.text.CompletionResult
Expand All @@ -17,7 +21,7 @@ import kotlinx.uuid.generateUUID

@OptIn(ExperimentalStdlibApi::class)
class GcpChat(apiEndpoint: String, projectId: String, modelId: String, token: String) :
Chat, Completion, AutoCloseable {
Chat, Completion, AutoCloseable, Embeddings {
private val client: GcpClient = GcpClient(apiEndpoint, projectId, modelId, token)

override val name: String = client.modelId
Expand All @@ -38,7 +42,7 @@ class GcpChat(apiEndpoint: String, projectId: String, modelId: String, token: St
getTimeMillis(),
client.modelId,
listOf(CompletionChoice(response, 0, null, null)),
Usage.ZERO
Usage.ZERO, // TODO: token usage - no information about usage provided by GCP codechat model
)
}

Expand All @@ -58,7 +62,7 @@ class GcpChat(apiEndpoint: String, projectId: String, modelId: String, token: St
client.modelId,
getTimeMillis().toInt(),
client.modelId,
Usage.ZERO,
Usage.ZERO, // TODO: token usage - no information about usage provided by GCP
listOf(Choice(Message(Role.ASSISTANT, response, Role.ASSISTANT.name), null, 0)),
)
}
Expand All @@ -83,12 +87,32 @@ class GcpChat(apiEndpoint: String, projectId: String, modelId: String, token: St
getTimeMillis().toInt(),
client.modelId,
listOf(ChatChunk(delta = ChatDelta(Role.ASSISTANT, response))),
Usage.ZERO,
Usage
.ZERO, // TODO: token usage - no information about usage provided by GCP for codechat
// model
)
)
}
}

/**
 * Requests embeddings from the GCP endpoint and adapts the provider response to the
 * generic [EmbeddingResult] shape. Each prediction's double values are narrowed to floats,
 * and the list position of the prediction is carried through as the embedding index.
 */
override suspend fun createEmbeddings(request: EmbeddingRequest): EmbeddingResult {
    val response = client.embeddings(request)
    val data =
        response.predictions.mapIndexed { position, prediction ->
            Embedding("embedding", prediction.embeddings.values.map(Double::toFloat), index = position)
        }
    return EmbeddingResult(data = data, usage = usage(response))
}

/**
 * Builds a [Usage] from a GCP embedding response. Only the total token count is available
 * (summed over all predictions); GCP does not break it down into prompt/completion tokens,
 * so those fields stay null.
 */
private fun usage(response: GcpClient.EmbeddingResponse): Usage {
    val tokenTotal =
        response.predictions.fold(0) { acc, prediction ->
            acc + prediction.embeddings.statistics.tokenCount
        }
    return Usage(
        totalTokens = tokenTotal,
        promptTokens = null,
        completionTokens = null,
    )
}

// Hardcoded to 0: no tokenizer is wired up for GCP models here, so callers relying on
// accurate client-side token accounting will undercount. TODO(review): implement once a
// tokenizer (or a usage-reporting API) is available for these models.
override fun tokensFromMessages(messages: List<Message>): Int = 0

override fun close() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package com.xebia.functional.xef.gcp
import com.xebia.functional.xef.AIError
import com.xebia.functional.xef.conversation.AutoClose
import com.xebia.functional.xef.conversation.autoClose
import com.xebia.functional.xef.llm.models.embeddings.EmbeddingRequest
import io.ktor.client.HttpClient
import io.ktor.client.call.body
import io.ktor.client.plugins.*
Expand All @@ -16,6 +17,7 @@ import io.ktor.http.HttpStatusCode
import io.ktor.http.contentType
import io.ktor.http.isSuccess
import io.ktor.serialization.kotlinx.json.json
import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable
import kotlinx.serialization.json.Json

Expand Down Expand Up @@ -112,6 +114,58 @@ class GcpClient(
else throw GcpClientException(response.status, response.bodyAsText())
}

/** Wire format for the `:predict` embedding call: one instance per input text. */
@Serializable
private data class GcpEmbeddingRequest(
val instances: List<GcpEmbeddingInstance>,
)

/** A single text to embed, as expected by the GCP prediction endpoint. */
@Serializable
private data class GcpEmbeddingInstance(
val content: String,
)

/** Top-level `:predict` response: one prediction per submitted instance. */
@Serializable
data class EmbeddingResponse(
val predictions: List<EmbeddingPredictions>,
)

/** A single prediction entry wrapping the embedding payload. */
@Serializable
data class EmbeddingPredictions(
val embeddings: PredictionEmbeddings,
)

/** The embedding vector plus its accompanying token statistics. */
@Serializable
data class PredictionEmbeddings(
val statistics: EmbeddingStatistics,
val values: List<Double>,
)

/**
 * Per-prediction statistics reported by GCP. `truncated` signals that the input text was
 * cut to the model's limit; `token_count` is mapped to camelCase via @SerialName.
 */
@Serializable
data class EmbeddingStatistics(
val truncated: Boolean,
@SerialName("token_count") val tokenCount: Int,
)

/**
 * Calls the GCP `:predict` endpoint for the configured model and returns the parsed
 * embedding response. The location is fixed to us-central1 in the URL, matching the rest
 * of this client. Throws [AIError.NoResponse] when the provider returns no predictions,
 * and [GcpClientException] on any non-success HTTP status.
 */
suspend fun embeddings(request: EmbeddingRequest): EmbeddingResponse {
    val payload =
        GcpEmbeddingRequest(
            instances = request.input.map { GcpEmbeddingInstance(it) },
        )
    val url =
        "https://$apiEndpoint/v1/projects/$projectId/locations/us-central1/publishers/google/models/$modelId:predict"
    val response =
        http.post(url) {
            header("Authorization", "Bearer $token")
            contentType(ContentType.Application.Json)
            setBody(payload)
        }
    if (!response.status.isSuccess())
        throw GcpClientException(response.status, response.bodyAsText())
    val parsed = response.body<EmbeddingResponse>()
    if (parsed.predictions.isEmpty()) throw AIError.NoResponse()
    return parsed
}

/** Thrown when the GCP API responds with a non-success HTTP status; carries the raw error body. */
class GcpClientException(val httpStatusCode: HttpStatusCode, val error: String) :
IllegalStateException("$httpStatusCode: $error")

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package com.xebia.functional.xef.gcp

import arrow.fx.coroutines.parMap
import com.xebia.functional.xef.embeddings.Embedding
import com.xebia.functional.xef.embeddings.Embeddings
import com.xebia.functional.xef.llm.models.embeddings.EmbeddingRequest
import com.xebia.functional.xef.llm.models.embeddings.RequestConfig

/**
 * [Embeddings] adapter backed by a GCP embedding-capable model. Documents are embedded in
 * chunks (default 400 per request) that run in parallel via parMap; chunk order — and thus
 * result order — is preserved.
 */
class GcpEmbeddings(private val gcpClient: com.xebia.functional.xef.llm.Embeddings) : Embeddings {
  override suspend fun embedDocuments(
    texts: List<String>,
    chunkSize: Int?,
    requestConfig: RequestConfig
  ): List<Embedding> {
    if (texts.isEmpty()) return emptyList()
    // One provider round-trip per chunk of input texts.
    suspend fun embedChunk(chunk: List<String>): List<Embedding> {
      val request = EmbeddingRequest(gcpClient.name, chunk, requestConfig.user.id)
      return gcpClient.createEmbeddings(request).data.map { Embedding(it.embedding) }
    }
    return texts
      .chunked(chunkSize ?: 400)
      .parMap { embedChunk(it) }
      .flatten()
  }

  override suspend fun embedQuery(text: String, requestConfig: RequestConfig): List<Embedding> =
    if (text.isEmpty()) emptyList()
    else embedDocuments(listOf(text), null, requestConfig)
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import com.xebia.functional.xef.embeddings.Embedding
import com.xebia.functional.xef.embeddings.Embeddings
import com.xebia.functional.xef.llm.models.chat.Message
import com.xebia.functional.xef.llm.models.chat.Role
import com.xebia.functional.xef.llm.models.embeddings.EmbeddingModel
import com.xebia.functional.xef.llm.models.embeddings.RequestConfig
import org.apache.lucene.analysis.standard.StandardAnalyzer
import org.apache.lucene.document.Document
Expand All @@ -30,8 +29,7 @@ open class Lucene(
private val similarity: VectorSimilarityFunction = VectorSimilarityFunction.EUCLIDEAN
) : VectorStore, AutoCloseable {

private val requestConfig =
RequestConfig(EmbeddingModel.TEXT_EMBEDDING_ADA_002, RequestConfig.Companion.User("user"))
private val requestConfig = RequestConfig(RequestConfig.Companion.User("user"))

override suspend fun addMemories(memories: List<Memory>) {
memories.forEach {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import com.xebia.functional.xef.embeddings.Embedding
import com.xebia.functional.xef.embeddings.Embeddings
import com.xebia.functional.xef.llm.models.chat.Message
import com.xebia.functional.xef.llm.models.chat.Role
import com.xebia.functional.xef.llm.models.embeddings.EmbeddingModel
import com.xebia.functional.xef.llm.models.embeddings.RequestConfig
import com.xebia.functional.xef.store.ConversationId
import com.xebia.functional.xef.store.Memory
Expand Down Expand Up @@ -50,8 +49,7 @@ class PGVectorStoreSpec :
collectionName = "test_collection",
distanceStrategy = PGDistanceStrategy.Euclidean,
preDeleteCollection = false,
requestConfig =
RequestConfig(EmbeddingModel.TEXT_EMBEDDING_ADA_002, RequestConfig.Companion.User("user")),
requestConfig = RequestConfig(RequestConfig.Companion.User("user")),
chunkSize = null
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class OpenAIEmbeddings(private val oaiClient: com.xebia.functional.xef.llm.Embed
requestConfig: RequestConfig
): List<Embedding> {
suspend fun createEmbeddings(texts: List<String>): List<Embedding> {
val req = EmbeddingRequest(requestConfig.model.modelName, texts, requestConfig.user.id)
val req = EmbeddingRequest(oaiClient.name, texts, requestConfig.user.id)
return oaiClient.createEmbeddings(req).data.map { Embedding(it.embedding) }
}
val lists: List<List<Embedding>> =
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package com.xebia.functional.xef.server.services
import com.xebia.functional.xef.conversation.autoClose
import com.xebia.functional.xef.conversation.llm.openai.OpenAI
import com.xebia.functional.xef.conversation.llm.openai.OpenAIEmbeddings
import com.xebia.functional.xef.llm.models.embeddings.EmbeddingModel
import com.xebia.functional.xef.llm.models.embeddings.RequestConfig
import com.xebia.functional.xef.server.http.routes.Provider
import com.xebia.functional.xef.store.PGVectorStore
Expand Down Expand Up @@ -69,7 +68,6 @@ class PostgresXefService(
Provider.OPENAI -> OpenAIEmbeddings(OpenAI(token).DEFAULT_EMBEDDING)
else -> OpenAIEmbeddings(OpenAI(token).DEFAULT_EMBEDDING)
}
val embeddingModel = EmbeddingModel.TEXT_EMBEDDING_ADA_002

return PGVectorStore(
vectorSize = config.vectorSize,
Expand All @@ -80,7 +78,6 @@ class PostgresXefService(
preDeleteCollection = config.preDeleteCollection,
requestConfig =
RequestConfig(
model = embeddingModel,
user = RequestConfig.Companion.User("user")
),
chunkSize = config.chunkSize
Expand Down

0 comments on commit fd12b91

Please sign in to comment.