query fine tuned models (#469)
* query fine tuned models (from branch #441-fine-tuning-oai)

* spotless

* clean build, make tests and mocks compile

* changes according to pr comments

---------

Co-authored-by: José Carlos Montañez <josecarlos.montanez@xebia.com>
Intex32 and Montagon authored Oct 3, 2023
1 parent 91b76f9 · commit 85fa396
Showing 20 changed files with 153 additions and 46 deletions.
@@ -12,6 +12,12 @@ sealed interface LLM : AutoCloseable {
val name
get() = modelType.name

/**
* Copies this instance and uses [modelType] for [LLM.modelType]. Has to return the most specific
* type of this instance!
*/
fun copy(modelType: ModelType): LLM

fun tokensFromMessages(
messages: List<Message>
): Int { // TODO: naive implementation with magic numbers
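As a minimal standalone sketch of the contract documented above (all names here, ModelSpec, Model, ProviderClient, MyChat, are hypothetical stand-ins rather than xef types): an override of copy has to return the implementing class itself, which is the pattern every model touched in this diff follows.

data class ModelSpec(val name: String)

interface Model {
  val spec: ModelSpec

  // Contract: return the most specific type of the implementing class.
  fun copy(spec: ModelSpec): Model
}

class ProviderClient(val apiKey: String)

// A provider-backed model keeps its client so copy() can rebuild itself with a
// new spec while reusing the same connection.
class MyChat(private val client: ProviderClient, override val spec: ModelSpec) : Model {
  override fun copy(spec: ModelSpec): MyChat = MyChat(client, spec) // covariant return: MyChat, not Model
}

fun main() {
  val base = MyChat(ProviderClient("hypothetical-key"), ModelSpec("base-model"))
  val fineTuned: MyChat = base.copy(ModelSpec("ft:base-model:custom")) // still statically a MyChat
  println(fineTuned.spec.name)
}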
@@ -12,6 +12,8 @@ class TestEmbeddings : Embeddings {

override val modelType: ModelType = ModelType.TODO("test-embeddings")

override fun copy(modelType: ModelType) = TestEmbeddings()

override suspend fun embedDocuments(
texts: List<String>,
requestConfig: RequestConfig,
@@ -18,6 +18,8 @@ class TestFunctionsModel(

var requests: MutableList<FunChatCompletionRequest> = mutableListOf()

override fun copy(modelType: ModelType) = TestFunctionsModel(modelType, responses)

override fun tokensFromMessages(messages: List<Message>): Int {
return messages.sumOf { it.content.length }
}
@@ -16,6 +16,8 @@ class TestModel(

var requests: MutableList<ChatCompletionRequest> = mutableListOf()

override fun copy(modelType: ModelType) = TestModel(modelType, responses)

override suspend fun createChatCompletion(
request: ChatCompletionRequest
): ChatCompletionResponse {
@@ -0,0 +1,23 @@
package com.xebia.functional.xef.conversation.finetuning

import com.xebia.functional.xef.conversation.llm.openai.OpenAI
import com.xebia.functional.xef.env.getenv
import com.xebia.functional.xef.prompt.Prompt

suspend fun main() {
val spawnModelId =
getenv("OPENAI_FINE_TUNED_MODEL_ID")
?: error("Please set the OPENAI_FINE_TUNED_MODEL_ID environment variable.")

val OAI = OpenAI()
val model = OAI.spawnModel(spawnModelId, OAI.GPT_3_5_TURBO)
OpenAI.conversation {
while (true) {
print("> ")
val question = readlnOrNull() ?: break
val answer = model.promptStreaming(Prompt(question), this)
answer.collect(::print)
println()
}
}
}
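A note on the example above: OPENAI_FINE_TUNED_MODEL_ID is expected to hold the id of a fine-tuned model that already exists on the account; with OpenAI's fine-tuning API such ids usually look like ft:gpt-3.5-turbo-0613:my-org::abc123 (an assumed illustration, not a value from this commit). spawnModel, added to the OpenAI provider further down, first checks the id via findModel and fails with "model not found" if the account does not expose it, then copies GPT_3_5_TURBO with a ModelType.FineTunedModel wrapping that id, keeping a reference to the base ModelType (presumably for token counting and context-length limits).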
@@ -66,6 +66,9 @@ interface GPT4All : AutoCloseable, Chat, Completion {

val llModel = LLModel(path)

override fun copy(modelType: ModelType) =
GPT4All(url, path)

override suspend fun createCompletion(request: CompletionRequest): CompletionResult =
with(request) {
val config = LLModel.config()
@@ -11,13 +11,16 @@ import com.xebia.functional.xef.llm.models.usage.Usage

class HuggingFaceLocalEmbeddings(
override val modelType: ModelType,
artifact: String,
private val artifact: String,
) : Embeddings {

private val tokenizer = HuggingFaceTokenizer.newInstance("${modelType.name}/$artifact")

override val name: String = HuggingFaceLocalEmbeddings::class.java.canonicalName

override fun copy(modelType: ModelType) =
HuggingFaceLocalEmbeddings(modelType, artifact)

override suspend fun createEmbeddings(request: EmbeddingRequest): EmbeddingResult {
val embedings = tokenizer.batchEncode(request.input)
return EmbeddingResult(
@@ -47,10 +47,8 @@ class GCP(projectId: String? = null, location: VertexAIRegion? = null, token: St

val defaultClient = GcpClient(config)

val CODECHAT by lazy { GcpChat(ModelType.TODO("codechat-bison@001"), defaultClient) }
val TEXT_EMBEDDING_GECKO by lazy {
GcpEmbeddings(ModelType.TODO("textembedding-gecko"), defaultClient)
}
val CODECHAT by lazy { GcpChat(this, ModelType.TODO("codechat-bison@001")) }
val TEXT_EMBEDDING_GECKO by lazy { GcpEmbeddings(this, ModelType.TODO("textembedding-gecko")) }

@JvmField val DEFAULT_CHAT = CODECHAT
@JvmField val DEFAULT_EMBEDDING = TEXT_EMBEDDING_GECKO
@@ -79,7 +79,7 @@ class GcpClient(
)
val response =
http.post(
"https://${config.location.officialName}-aiplatform.googleapis.com/v1/projects/${config.projectId}/locations/us-central1/publishers/google/models/$modelId:predict"
"https://${config.location.officialName}-aiplatform.googleapis.com/v1/projects/${config.projectId}/locations/${config.location.officialName}/publishers/google/models/$modelId:predict"
) {
header("Authorization", "Bearer ${config.token}")
contentType(ContentType.Application.Json)
@@ -1,7 +1,7 @@
package com.xebia.functional.xef.gcp.models

import com.xebia.functional.tokenizer.ModelType
import com.xebia.functional.xef.gcp.GcpClient
import com.xebia.functional.xef.gcp.GCP
import com.xebia.functional.xef.llm.Chat
import com.xebia.functional.xef.llm.models.chat.*
import com.xebia.functional.xef.llm.models.usage.Usage
@@ -13,10 +13,14 @@ import kotlinx.uuid.UUID
import kotlinx.uuid.generateUUID

class GcpChat(
private val provider: GCP, // TODO: use context receiver
override val modelType: ModelType,
private val client: GcpClient,
) : Chat {

private val client = provider.defaultClient

override fun copy(modelType: ModelType) = GcpChat(provider, modelType)

override suspend fun createChatCompletion(
request: ChatCompletionRequest
): ChatCompletionResponse {
@@ -1,7 +1,7 @@
package com.xebia.functional.xef.gcp.models

import com.xebia.functional.tokenizer.ModelType
import com.xebia.functional.xef.gcp.GcpClient
import com.xebia.functional.xef.gcp.GCP
import com.xebia.functional.xef.llm.Completion
import com.xebia.functional.xef.llm.models.text.CompletionChoice
import com.xebia.functional.xef.llm.models.text.CompletionRequest
@@ -12,10 +12,14 @@ import kotlinx.uuid.UUID
import kotlinx.uuid.generateUUID

class GcpCompletion(
private val provider: GCP, // TODO: use context receiver
override val modelType: ModelType,
private val client: GcpClient,
) : Completion {

private val client = provider.defaultClient

override fun copy(modelType: ModelType) = GcpCompletion(provider, modelType)

override suspend fun createCompletion(request: CompletionRequest): CompletionResult {
val response: String =
client.promptMessage(
@@ -1,6 +1,7 @@
package com.xebia.functional.xef.gcp.models

import com.xebia.functional.tokenizer.ModelType
import com.xebia.functional.xef.gcp.GCP
import com.xebia.functional.xef.gcp.GcpClient
import com.xebia.functional.xef.llm.Embeddings
import com.xebia.functional.xef.llm.models.embeddings.Embedding
@@ -9,10 +10,14 @@ import com.xebia.functional.xef.llm.models.embeddings.EmbeddingResult
import com.xebia.functional.xef.llm.models.usage.Usage

class GcpEmbeddings(
private val provider: GCP, // TODO: use context receiver
override val modelType: ModelType,
private val client: GcpClient,
) : Embeddings {

private val client = provider.defaultClient

override fun copy(modelType: ModelType) = GcpEmbeddings(provider, modelType)

override suspend fun createEmbeddings(request: EmbeddingRequest): EmbeddingResult {
fun requestToEmbedding(it: GcpClient.EmbeddingPredictions): Embedding =
Embedding(it.embeddings.values.map(Double::toFloat))
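The three GCP model files above (GcpChat, GcpCompletion, GcpEmbeddings) all receive the same treatment: each now takes its GCP provider instead of a bare GcpClient (with a TODO to move to context receivers), resolves the client via provider.defaultClient, and overrides copy(modelType) by rebuilding itself with the same provider, satisfying the contract added to LLM above.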
@@ -110,6 +110,9 @@ class PGVectorStoreSpec :
})

class TestLLM(override val modelType: ModelType = ModelType.ADA) : Chat, AutoCloseable {
override fun copy(modelType: ModelType) =
TestLLM(modelType)

override fun tokensFromMessages(messages: List<Message>): Int = messages.map { calculateTokens(it) }.sum()

private fun calculateTokens(message: Message): Int = message.content.split(" ").size + 2 // 2 is the role and name
@@ -145,6 +148,9 @@ private fun Embeddings.Companion.mock(
}
): Embeddings =
object : Embeddings {
override fun copy(modelType: ModelType): LLM {
throw NotImplementedError()
}
override suspend fun embedDocuments(
texts: List<String>,
requestConfig: RequestConfig,
@@ -1,7 +1,9 @@
package com.xebia.functional.xef.conversation.llm.openai

import arrow.core.nonEmptyListOf
import com.aallam.openai.api.exception.InvalidRequestException
import com.aallam.openai.api.logging.LogLevel
import com.aallam.openai.api.model.ModelId
import com.aallam.openai.client.LoggingConfig
import com.aallam.openai.client.OpenAI as OpenAIClient
import com.aallam.openai.client.OpenAIHost
@@ -57,7 +59,7 @@ class OpenAI(internal var token: String? = null, internal var host: String? = nu
}
}

val defaultClient =
internal val defaultClient =
OpenAIClient(
host = getHost()?.let { OpenAIHost(it) } ?: OpenAIHost.OpenAI,
token = getToken(),
@@ -66,51 +68,43 @@ class OpenAI(internal var token: String? = null, internal var host: String? = nu
)
.let { autoClose(it) }

val GPT_4 by lazy { autoClose(OpenAIChat(ModelType.GPT_4, defaultClient)) }
val GPT_4 by lazy { autoClose(OpenAIChat(this, ModelType.GPT_4)) }

val GPT_4_0314 by lazy {
autoClose(OpenAIFunChat(ModelType.GPT_4_0314, defaultClient)) // legacy
autoClose(OpenAIFunChat(this, ModelType.GPT_4_0314)) // legacy
}

val GPT_4_32K by lazy { autoClose(OpenAIChat(ModelType.GPT_4_32K, defaultClient)) }
val GPT_4_32K by lazy { autoClose(OpenAIChat(this, ModelType.GPT_4_32K)) }

val GPT_3_5_TURBO by lazy { autoClose(OpenAIChat(ModelType.GPT_3_5_TURBO, defaultClient)) }
val GPT_3_5_TURBO by lazy { autoClose(OpenAIChat(this, ModelType.GPT_3_5_TURBO)) }

val GPT_3_5_TURBO_16K by lazy {
autoClose(OpenAIChat(ModelType.GPT_3_5_TURBO_16_K, defaultClient))
}
val GPT_3_5_TURBO_16K by lazy { autoClose(OpenAIChat(this, ModelType.GPT_3_5_TURBO_16_K)) }

val GPT_3_5_TURBO_FUNCTIONS by lazy {
autoClose(OpenAIFunChat(ModelType.GPT_3_5_TURBO_FUNCTIONS, defaultClient))
autoClose(OpenAIFunChat(this, ModelType.GPT_3_5_TURBO_FUNCTIONS))
}

val GPT_3_5_TURBO_0301 by lazy {
autoClose(OpenAIChat(ModelType.GPT_3_5_TURBO, defaultClient)) // legacy
autoClose(OpenAIChat(this, ModelType.GPT_3_5_TURBO)) // legacy
}

val TEXT_DAVINCI_003 by lazy {
autoClose(OpenAICompletion(ModelType.TEXT_DAVINCI_003, defaultClient))
}
val TEXT_DAVINCI_003 by lazy { autoClose(OpenAICompletion(this, ModelType.TEXT_DAVINCI_003)) }

val TEXT_DAVINCI_002 by lazy {
autoClose(OpenAICompletion(ModelType.TEXT_DAVINCI_002, defaultClient))
}
val TEXT_DAVINCI_002 by lazy { autoClose(OpenAICompletion(this, ModelType.TEXT_DAVINCI_002)) }

val TEXT_CURIE_001 by lazy {
autoClose(OpenAICompletion(ModelType.TEXT_SIMILARITY_CURIE_001, defaultClient))
autoClose(OpenAICompletion(this, ModelType.TEXT_SIMILARITY_CURIE_001))
}

val TEXT_BABBAGE_001 by lazy {
autoClose(OpenAICompletion(ModelType.TEXT_BABBAGE_001, defaultClient))
}
val TEXT_BABBAGE_001 by lazy { autoClose(OpenAICompletion(this, ModelType.TEXT_BABBAGE_001)) }

val TEXT_ADA_001 by lazy { autoClose(OpenAICompletion(ModelType.TEXT_ADA_001, defaultClient)) }
val TEXT_ADA_001 by lazy { autoClose(OpenAICompletion(this, ModelType.TEXT_ADA_001)) }

val TEXT_EMBEDDING_ADA_002 by lazy {
autoClose(OpenAIEmbeddings(ModelType.TEXT_EMBEDDING_ADA_002, defaultClient))
autoClose(OpenAIEmbeddings(this, ModelType.TEXT_EMBEDDING_ADA_002))
}

val DALLE_2 by lazy { autoClose(OpenAIImages(ModelType.GPT_3_5_TURBO, defaultClient)) }
val DALLE_2 by lazy { autoClose(OpenAIImages(this, ModelType.GPT_3_5_TURBO)) }

@JvmField val DEFAULT_CHAT = GPT_3_5_TURBO_16K

@@ -120,8 +114,8 @@ class OpenAI(internal var token: String? = null, internal var host: String? = nu

@JvmField val DEFAULT_IMAGES = DALLE_2

fun supportedModels(): List<LLM> =
listOf(
fun supportedModels(): List<LLM> = // TODO: impl of abstract provider function
listOf(
GPT_4,
GPT_4_0314,
GPT_4_32K,
@@ -138,6 +132,28 @@ class OpenAI(internal var token: String? = null, internal var host: String? = nu
DALLE_2,
)

suspend fun findModel(modelId: String): Any? { // TODO: impl of abstract provider function
val model =
try {
defaultClient.model(ModelId(modelId))
} catch (e: InvalidRequestException) {
when (e.error.detail?.code) {
"model_not_found" -> return null
else -> throw e
}
}
return ModelType.TODO(model.id.id)
}

suspend fun <T : LLM> spawnModel(
modelId: String,
baseModel: T
): T { // TODO: impl of abstract provider function
if (findModel(modelId) == null) error("model not found")
return baseModel.copy(ModelType.FineTunedModel(modelId, baseModel = baseModel.modelType)) as? T
?: error("${baseModel::class} does not follow contract to return the most specific type")
}

companion object {

@JvmField val FromEnvironment: OpenAI = OpenAI()
@@ -4,8 +4,8 @@ import com.aallam.openai.api.chat.ChatChoice
import com.aallam.openai.api.chat.ChatMessage
import com.aallam.openai.api.chat.chatCompletionRequest
import com.aallam.openai.api.model.ModelId
import com.aallam.openai.client.OpenAI
import com.xebia.functional.tokenizer.ModelType
import com.xebia.functional.xef.conversation.llm.openai.OpenAI
import com.xebia.functional.xef.conversation.llm.openai.toInternal
import com.xebia.functional.xef.conversation.llm.openai.toOpenAI
import com.xebia.functional.xef.llm.Chat
@@ -14,10 +14,14 @@ import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.map

class OpenAIChat(
private val provider: OpenAI, // TODO: use context receiver
override val modelType: ModelType,
private val client: OpenAI,
) : Chat {

private val client = provider.defaultClient

override fun copy(modelType: ModelType) = OpenAIChat(provider, modelType)

override suspend fun createChatCompletion(
request: ChatCompletionRequest
): ChatCompletionResponse {
@@ -4,19 +4,23 @@ import com.aallam.openai.api.LegacyOpenAI
import com.aallam.openai.api.completion.Choice
import com.aallam.openai.api.completion.completionRequest
import com.aallam.openai.api.model.ModelId
import com.aallam.openai.client.OpenAI
import com.xebia.functional.tokenizer.ModelType
import com.xebia.functional.xef.conversation.llm.openai.OpenAI
import com.xebia.functional.xef.conversation.llm.openai.toInternal
import com.xebia.functional.xef.llm.Completion
import com.xebia.functional.xef.llm.models.text.CompletionChoice
import com.xebia.functional.xef.llm.models.text.CompletionRequest
import com.xebia.functional.xef.llm.models.text.CompletionResult

class OpenAICompletion(
private val provider: OpenAI, // TODO: use context receiver
override val modelType: ModelType,
private val client: OpenAI,
) : Completion {

private val client = provider.defaultClient

override fun copy(modelType: ModelType) = OpenAICompletion(provider, modelType)

@OptIn(LegacyOpenAI::class)
override suspend fun createCompletion(request: CompletionRequest): CompletionResult {
fun toInternal(it: Choice): CompletionChoice =