diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/AI.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/AI.kt
index 73ae511ad..9e9266a54 100644
--- a/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/AI.kt
+++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/AI.kt
@@ -4,17 +4,11 @@ import arrow.core.Either
import arrow.core.left
import arrow.core.right
import com.xebia.functional.xef.AIError
-import com.xebia.functional.xef.embeddings.Embeddings
import com.xebia.functional.xef.embeddings.OpenAIEmbeddings
import com.xebia.functional.xef.env.OpenAIConfig
-import com.xebia.functional.xef.llm.openai.KtorOpenAIClient
-import com.xebia.functional.xef.llm.openai.MockOpenAIClient
-import com.xebia.functional.xef.llm.openai.OpenAIClient
-import com.xebia.functional.xef.llm.openai.simpleMockAIClient
-import com.xebia.functional.xef.vectorstores.CombinedVectorStore
+import com.xebia.functional.xef.llm.openai.*
import com.xebia.functional.xef.vectorstores.LocalVectorStore
import com.xebia.functional.xef.vectorstores.VectorStore
-import kotlin.jvm.JvmName
import kotlin.time.ExperimentalTime
@DslMarker annotation class AiDsl
@@ -42,7 +36,7 @@ suspend inline fun <A> AI<A>.getOrElse(crossinline orElse: suspend (AIError) ->
AIScope(this) { orElse(it) }
@OptIn(ExperimentalTime::class, ExperimentalStdlibApi::class)
-suspend fun <A> AIScope(block: suspend AIScope.() -> A, orElse: suspend (AIError) -> A): A =
+suspend fun <A> AIScope(block: AI<A>, orElse: suspend (AIError) -> A): A =
try {
val openAIConfig = OpenAIConfig()
KtorOpenAIClient(openAIConfig).use { openAiClient ->
@@ -106,81 +100,3 @@ suspend fun <A> AI<A>.mock(mockAI: (String) -> String): Either<AIError, A> =
* throwing.
*/
suspend inline fun <A> AI<A>.getOrThrow(): A = getOrElse { throw it }
-
-/**
- * The [AIScope] is the context in which [AI] values are run. It encapsulates all the dependencies
- * required to run [AI] values, and provides convenient syntax for writing [AI] based programs.
- */
-class AIScope(
- val openAIClient: OpenAIClient,
- val context: VectorStore,
- val embeddings: Embeddings
-) {
-
- /**
- * Allows invoking [AI] values in the context of this [AIScope].
- *
- * ```kotlin
- * data class CovidNews(val title: String, val content: String)
- * val covidNewsToday = ai {
- * val now = LocalDateTime.now()
- * agent(search("$now covid-19 News")) {
- * prompt("write a paragraph of about 300 words about the latest news on covid-19 on $now")
- * }
- * }
- *
- * data class BreakingNews(val title: String, val content: String, val date: String)
- *
- * fun breakingNews(date: LocalDateTime): AI<BreakingNews> = ai {
- * agent(search("$date Breaking News")) {
- * prompt("Summarize all breaking news that happened on ${now.minusDays(it)} in about 300 words")
- * }
- * }
- *
- * suspend fun AIScope.breakingNewsLastWeek(): List<BreakingNews> {
- * val now = LocalDateTime.now()
- * return (0..7).parMap { breakingNews(now.minusDays(it)).invoke() }
- * }
- *
- * fun news(): AI<List<News>> = ai {
- * val covidNews = parZip(
- * { covidNewsToday() },
- * { breakingNewsLastWeek() }
- * ) { covidNews, breakingNews -> listOf(covidNews) + breakingNews }
- * }
- * ```
- */
- @AiDsl @JvmName("invokeAI") suspend operator fun <A> AI<A>.invoke(): A = invoke(this@AIScope)
-
- @AiDsl
- suspend fun extendContext(vararg docs: String) {
- context.addTexts(docs.toList())
- }
-
- /**
- * Creates a nested scope that combines the provided [store] with the outer _store_. This is done
- * using [CombinedVectorStore].
- *
- * **Note:** if the implementation of [VectorStore] is relying on resources you're manually
- * responsible for closing any potential resources.
- */
- @AiDsl
- suspend fun <A> contextScope(store: VectorStore, block: AI<A>): A =
- AIScope(
- this@AIScope.openAIClient,
- CombinedVectorStore(store, this@AIScope.context),
- this@AIScope.embeddings
- )
- .block()
-
- @AiDsl
- suspend fun <A> contextScope(block: AI<A>): A = contextScope(LocalVectorStore(embeddings), block)
-
- /** Add new [docs] to the [context], and then executes the [block]. */
- @AiDsl
- @JvmName("contextScopeWithDocs")
- suspend fun <A> contextScope(docs: List<String>, block: AI<A>): A = contextScope {
- extendContext(*docs.toTypedArray())
- block(this)
- }
-}
diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/AIScope.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/AIScope.kt
new file mode 100644
index 000000000..d76f03205
--- /dev/null
+++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/AIScope.kt
@@ -0,0 +1,501 @@
+package com.xebia.functional.xef.auto
+
+import arrow.core.nonFatalOrThrow
+import arrow.core.raise.catch
+import com.xebia.functional.tokenizer.Encoding
+import com.xebia.functional.tokenizer.ModelType
+import com.xebia.functional.tokenizer.truncateText
+import com.xebia.functional.xef.AIError
+import com.xebia.functional.xef.embeddings.Embeddings
+import com.xebia.functional.xef.llm.openai.*
+import com.xebia.functional.xef.llm.openai.functions.CFunction
+import com.xebia.functional.xef.llm.openai.images.ImagesGenerationRequest
+import com.xebia.functional.xef.llm.openai.images.ImagesGenerationResponse
+import com.xebia.functional.xef.prompt.Prompt
+import com.xebia.functional.xef.vectorstores.CombinedVectorStore
+import com.xebia.functional.xef.vectorstores.LocalVectorStore
+import com.xebia.functional.xef.vectorstores.VectorStore
+import io.github.oshai.kotlinlogging.KLogger
+import io.github.oshai.kotlinlogging.KotlinLogging
+import kotlin.jvm.JvmName
+
+/**
+ * The [AIScope] is the context in which [AI] values are run. It encapsulates all the dependencies
+ * required to run [AI] values, and provides convenient syntax for writing [AI] based programs.
+ */
+class AIScope(
+ val openAIClient: OpenAIClient,
+ val context: VectorStore,
+ val embeddings: Embeddings,
+ val logger: KLogger = KotlinLogging.logger {}
+) {
+
+ /**
+ * Allows invoking [AI] values in the context of this [AIScope].
+ *
+ * ```kotlin
+ * data class CovidNews(val title: String, val content: String)
+ * val covidNewsToday = ai {
+ * val now = LocalDateTime.now()
+ * agent(search("$now covid-19 News")) {
+ * prompt("write a paragraph of about 300 words about the latest news on covid-19 on $now")
+ * }
+ * }
+ *
+ * data class BreakingNews(val title: String, val content: String, val date: String)
+ *
+ * fun breakingNews(date: LocalDateTime): AI<BreakingNews> = ai {
+ * agent(search("$date Breaking News")) {
+ * prompt("Summarize all breaking news that happened on ${now.minusDays(it)} in about 300 words")
+ * }
+ * }
+ *
+ * suspend fun AIScope.breakingNewsLastWeek(): List<BreakingNews> {
+ * val now = LocalDateTime.now()
+ * return (0..7).parMap { breakingNews(now.minusDays(it)).invoke() }
+ * }
+ *
+ * fun news(): AI<List<News>> = ai {
+ * val covidNews = parZip(
+ * { covidNewsToday() },
+ * { breakingNewsLastWeek() }
+ * ) { covidNews, breakingNews -> listOf(covidNews) + breakingNews }
+ * }
+ * ```
+ */
+ @AiDsl @JvmName("invokeAI") suspend operator fun <A> AI<A>.invoke(): A = invoke(this@AIScope)
+
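+ /**
+ * Stores [docs] in the [context] vector store so that later prompts can retrieve them through
+ * similarity search.
+ *
+ * A minimal usage sketch (the document text is illustrative):
+ *
+ * ```kotlin
+ * extendContext("xef.ai is a library for building AI-powered applications in Kotlin.")
+ * ```
+ */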
+ @AiDsl
+ suspend fun extendContext(vararg docs: String) {
+ context.addTexts(docs.toList())
+ }
+
+ /**
+ * Creates a nested scope that combines the provided [store] with the outer _store_. This is done
+ * using [CombinedVectorStore].
+ *
+ * **Note:** if the [VectorStore] implementation relies on external resources, you are manually
+ * responsible for closing them.
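+ *
+ * A minimal usage sketch (`myVectorStore` stands in for any [VectorStore] implementation you
+ * provide):
+ *
+ * ```kotlin
+ * val summary: String = contextScope(myVectorStore) {
+ *   promptMessage("Summarize the indexed documents").joinToString("\n")
+ * }
+ * ```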
+ */
+ @AiDsl
+ suspend fun <A> contextScope(store: VectorStore, block: AI<A>): A =
+ AIScope(
+ this@AIScope.openAIClient,
+ CombinedVectorStore(store, this@AIScope.context),
+ this@AIScope.embeddings
+ )
+ .block()
+
+ @AiDsl
+ suspend fun <A> contextScope(block: AI<A>): A = contextScope(LocalVectorStore(embeddings), block)
+
+ /** Add new [docs] to the [context], and then executes the [block]. */
+ @AiDsl
+ @JvmName("contextScopeWithDocs")
+ suspend fun <A> contextScope(docs: List<String>, block: AI<A>): A = contextScope {
+ extendContext(*docs.toTypedArray())
+ block(this)
+ }
+
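+ /**
+ * Prompts the model and deserializes its reply into an [A] using [serializer]. The model is
+ * steered towards the given [functions], and deserialization is retried up to
+ * [maxDeserializationAttempts] times before failing with [AIError.JsonParsing].
+ *
+ * A minimal usage sketch; `Recipe`, `recipeFunctionSchema` and `decodeRecipe` are illustrative
+ * placeholders, not part of this API:
+ *
+ * ```kotlin
+ * data class Recipe(val title: String, val ingredients: List<String>)
+ *
+ * suspend fun AIScope.recipe(dish: String): Recipe =
+ *   prompt(
+ *     prompt = Prompt("Write a recipe for $dish"),
+ *     functions = listOf(recipeFunctionSchema), // CFunction describing the Recipe JSON schema
+ *     serializer = { json -> decodeRecipe(json) } // any JSON decoder producing a Recipe
+ *   )
+ * ```
+ */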
+ @AiDsl
+ @JvmName("promptWithSerializer")
+ suspend fun <A> prompt(
+ prompt: Prompt,
+ functions: List<CFunction>,
+ serializer: (json: String) -> A,
+ maxDeserializationAttempts: Int = 5,
+ model: LLMModel = LLMModel.GPT_3_5_TURBO_FUNCTIONS,
+ user: String = "testing",
+ echo: Boolean = false,
+ n: Int = 1,
+ temperature: Double = 0.0,
+ bringFromContext: Int = 10,
+ minResponseTokens: Int = 500,
+ ): A {
+ return tryDeserialize(serializer, maxDeserializationAttempts) {
+ promptMessage(
+ prompt = prompt,
+ model = model,
+ functions = functions,
+ user = user,
+ echo = echo,
+ n = n,
+ temperature = temperature,
+ bringFromContext = bringFromContext,
+ minResponseTokens = minResponseTokens
+ )
+ }
+ }
+
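+ /**
+ * Runs [agent] and feeds its first reply to [serializer], retrying failed deserializations up to
+ * [maxDeserializationAttempts] times before raising [AIError.JsonParsing], and raising
+ * [AIError.NoResponse] when the agent returns no replies.
+ *
+ * A minimal sketch (the parsing inside the serializer is illustrative):
+ *
+ * ```kotlin
+ * val continents: Int =
+ *   tryDeserialize(serializer = { it.trim().toInt() }, maxDeserializationAttempts = 3) {
+ *     promptMessage("Reply only with the number of continents on Earth")
+ *   }
+ * ```
+ */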
+ suspend fun <A> AIScope.tryDeserialize(
+ serializer: (json: String) -> A,
+ maxDeserializationAttempts: Int,
+ agent: AI<List<String>>
+ ): A {
+ val logger = KotlinLogging.logger {}
+ (0 until maxDeserializationAttempts).forEach { currentAttempts ->
+ val result = agent().firstOrNull() ?: throw AIError.NoResponse()
+ catch({
+ return@tryDeserialize serializer(result)
+ }) { e: Throwable ->
+ logger.error(e) { "Error deserializing response: $result\n${e.message}" }
+ if (currentAttempts == maxDeserializationAttempts)
+ throw AIError.JsonParsing(result, maxDeserializationAttempts, e.nonFatalOrThrow())
+ // TODO else log attempt ?
+ }
+ }
+ throw AIError.NoResponse()
+ }
+
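+ /**
+ * Sends [question] to the selected [model] and returns the raw text completions, after enriching
+ * the prompt with up to [bringFromContext] documents retrieved from the [context] store.
+ *
+ * A minimal usage sketch (assumes it runs inside an `ai { }` block or another [AIScope]):
+ *
+ * ```kotlin
+ * val answers: List<String> = promptMessage("What is the capital of France?")
+ * println(answers.firstOrNull())
+ * ```
+ */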
+ @AiDsl
+ suspend fun promptMessage(
+ question: String,
+ model: LLMModel = LLMModel.GPT_3_5_TURBO,
+ functions: List<CFunction> = emptyList(),
+ user: String = "testing",
+ echo: Boolean = false,
+ n: Int = 1,
+ temperature: Double = 0.0,
+ bringFromContext: Int = 10,
+ minResponseTokens: Int = 500
+ ): List<String> =
+ promptMessage(
+ Prompt(question),
+ model,
+ functions,
+ user,
+ echo,
+ n,
+ temperature,
+ bringFromContext,
+ minResponseTokens
+ )
+
+ @AiDsl
+ suspend fun promptMessage(
+ prompt: Prompt,
+ model: LLMModel = LLMModel.GPT_3_5_TURBO,
+ functions: List<CFunction> = emptyList(),
+ user: String = "testing",
+ echo: Boolean = false,
+ n: Int = 1,
+ temperature: Double = 0.0,
+ bringFromContext: Int = 10,
+ minResponseTokens: Int
+ ): List<String> {
+ return when (model.kind) {
+ LLMModel.Kind.Completion ->
+ callCompletionEndpoint(
+ prompt.message,
+ model,
+ user,
+ echo,
+ n,
+ temperature,
+ bringFromContext,
+ minResponseTokens
+ )
+ LLMModel.Kind.Chat ->
+ callChatEndpoint(
+ prompt.message,
+ model,
+ user,
+ n,
+ temperature,
+ bringFromContext,
+ minResponseTokens
+ )
+ LLMModel.Kind.ChatWithFunctions ->
+ callChatEndpointWithFunctionsSupport(
+ prompt.message,
+ model,
+ functions,
+ user,
+ n,
+ temperature,
+ bringFromContext,
+ minResponseTokens
+ )
+ .map { it.arguments }
+ }
+ }
+
+ private suspend fun callCompletionEndpoint(
+ prompt: String,
+ model: LLMModel,
+ user: String = "testing",
+ echo: Boolean = false,
+ n: Int = 1,
+ temperature: Double = 0.0,
+ bringFromContext: Int,
+ minResponseTokens: Int
+ ): List<String> {
+ val promptWithContext: String =
+ promptWithContext(prompt, bringFromContext, model.modelType, minResponseTokens)
+
+ val maxTokens: Int = checkTotalLeftTokens(model.modelType, "", promptWithContext)
+
+ val request =
+ CompletionRequest(
+ model = model.name,
+ user = user,
+ prompt = promptWithContext,
+ echo = echo,
+ n = n,
+ temperature = temperature,
+ maxTokens = maxTokens
+ )
+ return openAIClient.createCompletion(request).choices.map { it.text }
+ }
+
+ private suspend fun callChatEndpoint(
+ prompt: String,
+ model: LLMModel,
+ user: String = "testing",
+ n: Int = 1,
+ temperature: Double = 0.0,
+ bringFromContext: Int,
+ minResponseTokens: Int
+ ): List<String> {
+ val role: String = Role.system.name
+ val promptWithContext: String =
+ promptWithContext(prompt, bringFromContext, model.modelType, minResponseTokens)
+ val messages: List<Message> = listOf(Message(role, promptWithContext))
+ val maxTokens: Int = checkTotalLeftChatTokens(messages, model)
+ val request =
+ ChatCompletionRequest(
+ model = model.name,
+ user = user,
+ messages = messages,
+ n = n,
+ temperature = temperature,
+ maxTokens = maxTokens
+ )
+ return openAIClient.createChatCompletion(request).choices.map { it.message.content }
+ }
+
+ private suspend fun callChatEndpointWithFunctionsSupport(
+ prompt: String,
+ model: LLMModel,
+ functions: List<CFunction>,
+ user: String = "function",
+ n: Int = 1,
+ temperature: Double = 0.0,
+ bringFromContext: Int,
+ minResponseTokens: Int
+ ): List<FunctionCall> {
+ val role: String = Role.user.name
+ val firstFnName: String? = functions.firstOrNull()?.name
+ val promptWithContext: String =
+ promptWithContext(prompt, bringFromContext, model.modelType, minResponseTokens)
+ val messages: List<Message> = listOf(Message(role, promptWithContext))
+ val maxTokens: Int = checkTotalLeftChatTokens(messages, model)
+ val request =
+ ChatCompletionRequestWithFunctions(
+ model = model.name,
+ user = user,
+ messages = messages,
+ n = n,
+ temperature = temperature,
+ maxTokens = maxTokens,
+ functions = functions,
+ functionCall = mapOf("name" to (firstFnName ?: ""))
+ )
+ return openAIClient.createChatCompletionWithFunctions(request).choices.map {
+ it.message.functionCall
+ }
+ }
+
+ private suspend fun promptWithContext(
+ prompt: String,
+ bringFromContext: Int,
+ modelType: ModelType,
+ minResponseTokens: Int
+ ): String {
+ val ctxInfo: List<String> = context.similaritySearch(prompt, bringFromContext)
+ return createPromptWithContextAwareOfTokens(
+ ctxInfo = ctxInfo,
+ modelType = modelType,
+ prompt = prompt,
+ minResponseTokens = minResponseTokens
+ )
+ }
+
+ private fun createPromptWithContextAwareOfTokens(
+ ctxInfo: List<String>,
+ modelType: ModelType,
+ prompt: String,
+ minResponseTokens: Int,
+ ): String {
+ val maxContextLength: Int = modelType.maxContextLength
+ val promptTokens: Int = modelType.encoding.countTokens(prompt)
+ val remainingTokens: Int = maxContextLength - promptTokens - minResponseTokens
+
+ return if (ctxInfo.isNotEmpty() && remainingTokens > minResponseTokens) {
+ val ctx: String = ctxInfo.joinToString("\n")
+
+ if (promptTokens >= maxContextLength) {
+ throw AIError.PromptExceedsMaxTokenLength(prompt, promptTokens, maxContextLength)
+ }
+ // truncate the context if it's too long based on the max tokens calculated considering the
+ // existing prompt tokens
+ // alternatively we could summarize the context, but that's not implemented yet
+ val ctxTruncated: String = modelType.encoding.truncateText(ctx, remainingTokens)
+
+ """|```Context
+ |${ctxTruncated}
+ |```
+ |The context is related to the question try to answer the `goal` as best as you can
+ |or provide information about the found content
+ |```goal
+ |${prompt}
+ |```
+ |ANSWER:
+ |"""
+ .trimMargin()
+ } else prompt
+ }
+
+ private fun checkTotalLeftTokens(
+ modelType: ModelType,
+ role: String,
+ promptWithContext: String
+ ): Int =
+ with(modelType) {
+ val roleTokens: Int = encoding.countTokens(role)
+ val padding = 20 // reserve 20 tokens for additional symbols around the context
+ val promptTokens: Int = encoding.countTokens(promptWithContext)
+ val takenTokens: Int = roleTokens + promptTokens + padding
+ val totalLeftTokens: Int = maxContextLength - takenTokens
+ if (totalLeftTokens < 0) {
+ throw AIError.PromptExceedsMaxTokenLength(promptWithContext, takenTokens, maxContextLength)
+ }
+ logger.debug {
+ "Tokens -- used: $takenTokens, model max: $maxContextLength, left: $totalLeftTokens"
+ }
+ totalLeftTokens
+ }
+
+ private fun checkTotalLeftChatTokens(messages: List<Message>, model: LLMModel): Int {
+ val maxContextLength: Int = model.modelType.maxContextLength
+ val messagesTokens: Int = tokensFromMessages(messages, model)
+ val totalLeftTokens: Int = maxContextLength - messagesTokens
+ if (totalLeftTokens < 0) {
+ throw AIError.MessagesExceedMaxTokenLength(messages, messagesTokens, maxContextLength)
+ }
+ logger.debug {
+ "Tokens -- used: $messagesTokens, model max: $maxContextLength, left: $totalLeftTokens"
+ }
+ return totalLeftTokens
+ }
+
+ private fun tokensFromMessages(messages: List<Message>, model: LLMModel): Int =
+ when (model) {
+ LLMModel.GPT_3_5_TURBO -> {
+ val paddingTokens = 5 // otherwise if the model changes, it might later fail
+ val fallbackModel: LLMModel = LLMModel.GPT_3_5_TURBO_0301
+ logger.debug {
+ "Warning: ${model.name} may change over time. " +
+ "Returning messages num tokens assuming ${fallbackModel.name} + $paddingTokens padding tokens."
+ }
+ tokensFromMessages(messages, fallbackModel) + paddingTokens
+ }
+ LLMModel.GPT_4,
+ LLMModel.GPT_4_32K -> {
+ val paddingTokens = 5 // otherwise if the model changes, it might later fail
+ val fallbackModel: LLMModel = LLMModel.GPT_4_0314
+ logger.debug {
+ "Warning: ${model.name} may change over time. " +
+ "Returning messages num tokens assuming ${fallbackModel.name} + $paddingTokens padding tokens."
+ }
+ tokensFromMessages(messages, fallbackModel) + paddingTokens
+ }
+ LLMModel.GPT_3_5_TURBO_0301 ->
+ model.modelType.encoding.countTokensFromMessages(
+ messages,
+ tokensPerMessage = 4,
+ tokensPerName = 0
+ )
+ LLMModel.GPT_4_0314 ->
+ model.modelType.encoding.countTokensFromMessages(
+ messages,
+ tokensPerMessage = 3,
+ tokensPerName = 2
+ )
+ else -> {
+ val paddingTokens = 20
+ val fallbackModel: LLMModel = LLMModel.GPT_3_5_TURBO_0301
+ logger.debug {
+ "Warning: calculation of tokens is partially supported for ${model.name} . " +
+ "Returning messages num tokens assuming ${fallbackModel.name} + $paddingTokens padding tokens."
+ }
+ tokensFromMessages(messages, fallbackModel) + paddingTokens
+ }
+ }
+
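+ /**
+ * Estimates the tokens a chat request will consume, following the heuristic from OpenAI's
+ * token-counting guide: each message pays a fixed overhead of [tokensPerMessage] (plus
+ * [tokensPerName] when a name is present) on top of its role and content tokens, and the final
+ * `+ 3` accounts for the reply being primed with the assistant role.
+ */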
+ private fun Encoding.countTokensFromMessages(
+ messages: List<Message>,
+ tokensPerMessage: Int,
+ tokensPerName: Int
+ ): Int =
+ messages.sumOf { message ->
+ countTokens(message.role) +
+ countTokens(message.content) +
+ tokensPerMessage +
+ (message.name?.let { tokensPerName } ?: 0)
+ } + 3
+
+ /**
+ * Runs a [prompt] describing the images you want to generate within the context of [AIScope].
+ * Returns an [ImagesGenerationResponse] containing the creation time and the URLs of the generated images.
+ *
+ * @param prompt a [Prompt] describing the images you want to generate.
+ * @param numberImages number of images to generate.
+ * @param size the size of the images to generate.
+ */
+ suspend fun images(
+ prompt: String,
+ user: String = "testing",
+ numberImages: Int = 1,
+ size: String = "1024x1024",
+ bringFromContext: Int = 10
+ ): ImagesGenerationResponse = images(Prompt(prompt), user, numberImages, size, bringFromContext)
+
+ /**
+ * Runs a [prompt] describing the images you want to generate within the context of [AIScope].
+ * Returns an [ImagesGenerationResponse] containing the creation time and the URLs of the generated images.
+ *
+ * @param prompt a [Prompt] describing the images you want to generate.
+ * @param numberImages number of images to generate.
+ * @param size the size of the images to generate.
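+ *
+ * A minimal usage sketch (assumes it runs inside an `ai { }` block or another [AIScope]):
+ *
+ * ```kotlin
+ * val response: ImagesGenerationResponse =
+ *   images(Prompt("A watercolor painting of a lighthouse at dawn"), numberImages = 2)
+ * ```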
+ */
+ suspend fun images(
+ prompt: Prompt,
+ user: String = "testing",
+ numberImages: Int = 1,
+ size: String = "1024x1024",
+ bringFromContext: Int = 10
+ ): ImagesGenerationResponse {
+ val ctxInfo = context.similaritySearch(prompt.message, bringFromContext)
+ val promptWithContext =
+ if (ctxInfo.isNotEmpty()) {
+ """|Instructions: Use the [Information] below delimited by 3 backticks to accomplish
+ |the [Objective] at the end of the prompt.
+ |Try to match the data returned in the [Objective] with this [Information] as best as you can.
+ |[Information]:
+ |```
+ |${ctxInfo.joinToString("\n")}
+ |```
+ |$prompt"""
+ .trimMargin()
+ } else prompt.message
+ val request =
+ ImagesGenerationRequest(
+ prompt = promptWithContext,
+ numberImages = numberImages,
+ size = size,
+ user = user
+ )
+ return openAIClient.createImages(request)
+ }
+}
diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/DeserializerLLMAgent.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/DeserializerLLMAgent.kt
deleted file mode 100644
index e0404fe60..000000000
--- a/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/DeserializerLLMAgent.kt
+++ /dev/null
@@ -1,64 +0,0 @@
-@file:JvmMultifileClass
-@file:JvmName("Agent")
-
-package com.xebia.functional.xef.auto
-
-import arrow.core.nonFatalOrThrow
-import arrow.core.raise.catch
-import com.xebia.functional.xef.AIError
-import com.xebia.functional.xef.llm.openai.LLMModel
-import com.xebia.functional.xef.llm.openai.functions.CFunction
-import com.xebia.functional.xef.prompt.Prompt
-import io.github.oshai.kotlinlogging.KotlinLogging
-import kotlin.jvm.JvmMultifileClass
-import kotlin.jvm.JvmName
-
-@AiDsl
-@JvmName("promptWithSerializer")
-suspend fun <A> AIScope.prompt(
- prompt: Prompt,
- functions: List<CFunction>,
- serializer: (json: String) -> A,
- maxDeserializationAttempts: Int = 5,
- model: LLMModel = LLMModel.GPT_3_5_TURBO_FUNCTIONS,
- user: String = "testing",
- echo: Boolean = false,
- n: Int = 1,
- temperature: Double = 0.0,
- bringFromContext: Int = 10,
- minResponseTokens: Int = 500,
-): A {
- return tryDeserialize(serializer, maxDeserializationAttempts) {
- promptMessage(
- prompt = prompt,
- model = model,
- functions = functions,
- user = user,
- echo = echo,
- n = n,
- temperature = temperature,
- bringFromContext = bringFromContext,
- minResponseTokens = minResponseTokens
- )
- }
-}
-
-suspend fun <A> AIScope.tryDeserialize(
- serializer: (json: String) -> A,
- maxDeserializationAttempts: Int,
- agent: AI<List<String>>
-): A {
- val logger = KotlinLogging.logger {}
- (0 until maxDeserializationAttempts).forEach { currentAttempts ->
- val result = agent().firstOrNull() ?: throw AIError.NoResponse()
- catch({
- return@tryDeserialize serializer(result)
- }) { e: Throwable ->
- logger.error(e) { "Error deserializing response: $result\n${e.message}" }
- if (currentAttempts == maxDeserializationAttempts)
- throw AIError.JsonParsing(result, maxDeserializationAttempts, e.nonFatalOrThrow())
- // TODO else log attempt ?
- }
- }
- throw AIError.NoResponse()
-}
diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/ImageGenerationAgent.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/ImageGenerationAgent.kt
deleted file mode 100644
index 6a18ec5de..000000000
--- a/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/ImageGenerationAgent.kt
+++ /dev/null
@@ -1,64 +0,0 @@
-@file:JvmMultifileClass
-@file:JvmName("Agent")
-
-package com.xebia.functional.xef.auto
-
-import com.xebia.functional.xef.llm.openai.images.ImagesGenerationRequest
-import com.xebia.functional.xef.llm.openai.images.ImagesGenerationResponse
-import com.xebia.functional.xef.prompt.Prompt
-import kotlin.jvm.JvmMultifileClass
-import kotlin.jvm.JvmName
-
-/**
- * Run a [prompt] describes the images you want to generate within the context of [AIScope]. Returns
- * a [ImagesGenerationResponse] containing time and urls with images generated.
- *
- * @param prompt a [Prompt] describing the images you want to generate.
- * @param numberImages number of images to generate.
- * @param size the size of the images to generate.
- */
-suspend fun AIScope.images(
- prompt: String,
- user: String = "testing",
- numberImages: Int = 1,
- size: String = "1024x1024",
- bringFromContext: Int = 10
-): ImagesGenerationResponse = images(Prompt(prompt), user, numberImages, size, bringFromContext)
-
-/**
- * Run a [prompt] describes the images you want to generate within the context of [AIScope]. Returns
- * a [ImagesGenerationResponse] containing time and urls with images generated.
- *
- * @param prompt a [Prompt] describing the images you want to generate.
- * @param numberImages number of images to generate.
- * @param size the size of the images to generate.
- */
-suspend fun AIScope.images(
- prompt: Prompt,
- user: String = "testing",
- numberImages: Int = 1,
- size: String = "1024x1024",
- bringFromContext: Int = 10
-): ImagesGenerationResponse {
- val ctxInfo = context.similaritySearch(prompt.message, bringFromContext)
- val promptWithContext =
- if (ctxInfo.isNotEmpty()) {
- """|Instructions: Use the [Information] below delimited by 3 backticks to accomplish
- |the [Objective] at the end of the prompt.
- |Try to match the data returned in the [Objective] with this [Information] as best as you can.
- |[Information]:
- |```
- |${ctxInfo.joinToString("\n")}
- |```
- |$prompt"""
- .trimMargin()
- } else prompt.message
- val request =
- ImagesGenerationRequest(
- prompt = promptWithContext,
- numberImages = numberImages,
- size = size,
- user = user
- )
- return openAIClient.createImages(request)
-}
diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/LLMAgent.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/LLMAgent.kt
deleted file mode 100644
index f0d56b0bf..000000000
--- a/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/LLMAgent.kt
+++ /dev/null
@@ -1,326 +0,0 @@
-@file:JvmMultifileClass
-@file:JvmName("Agent")
-
-package com.xebia.functional.xef.auto
-
-import com.xebia.functional.tokenizer.Encoding
-import com.xebia.functional.tokenizer.ModelType
-import com.xebia.functional.tokenizer.truncateText
-import com.xebia.functional.xef.AIError
-import com.xebia.functional.xef.llm.openai.*
-import com.xebia.functional.xef.llm.openai.functions.CFunction
-import com.xebia.functional.xef.prompt.Prompt
-import io.github.oshai.kotlinlogging.KLogger
-import io.github.oshai.kotlinlogging.KotlinLogging
-import kotlin.jvm.JvmMultifileClass
-import kotlin.jvm.JvmName
-
-private val logger: KLogger by lazy { KotlinLogging.logger {} }
-
-@AiDsl
-suspend fun AIScope.promptMessage(
- question: String,
- model: LLMModel = LLMModel.GPT_3_5_TURBO,
- functions: List<CFunction> = emptyList(),
- user: String = "testing",
- echo: Boolean = false,
- n: Int = 1,
- temperature: Double = 0.0,
- bringFromContext: Int = 10,
- minResponseTokens: Int = 500
-): List<String> =
- promptMessage(
- Prompt(question),
- model,
- functions,
- user,
- echo,
- n,
- temperature,
- bringFromContext,
- minResponseTokens
- )
-
-@AiDsl
-suspend fun AIScope.promptMessage(
- prompt: Prompt,
- model: LLMModel = LLMModel.GPT_3_5_TURBO,
- functions: List<CFunction> = emptyList(),
- user: String = "testing",
- echo: Boolean = false,
- n: Int = 1,
- temperature: Double = 0.0,
- bringFromContext: Int = 10,
- minResponseTokens: Int
-): List<String> {
- return when (model.kind) {
- LLMModel.Kind.Completion ->
- callCompletionEndpoint(
- prompt.message,
- model,
- user,
- echo,
- n,
- temperature,
- bringFromContext,
- minResponseTokens
- )
- LLMModel.Kind.Chat ->
- callChatEndpoint(
- prompt.message,
- model,
- user,
- n,
- temperature,
- bringFromContext,
- minResponseTokens
- )
- LLMModel.Kind.ChatWithFunctions ->
- callChatEndpointWithFunctionsSupport(
- prompt.message,
- model,
- functions,
- user,
- n,
- temperature,
- bringFromContext,
- minResponseTokens
- )
- .map { it.arguments }
- }
-}
-
-private fun createPromptWithContextAwareOfTokens(
- ctxInfo: List<String>,
- modelType: ModelType,
- prompt: String,
- minResponseTokens: Int,
-): String {
- val maxContextLength: Int = modelType.maxContextLength
- val promptTokens: Int = modelType.encoding.countTokens(prompt)
- val remainingTokens: Int = maxContextLength - promptTokens - minResponseTokens
-
- return if (ctxInfo.isNotEmpty() && remainingTokens > minResponseTokens) {
- val ctx: String = ctxInfo.joinToString("\n")
-
- if (promptTokens >= maxContextLength) {
- throw AIError.PromptExceedsMaxTokenLength(prompt, promptTokens, maxContextLength)
- }
- // truncate the context if it's too long based on the max tokens calculated considering the
- // existing prompt tokens
- // alternatively we could summarize the context, but that's not implemented yet
- val ctxTruncated: String = modelType.encoding.truncateText(ctx, remainingTokens)
-
- """|```Context
- |${ctxTruncated}
- |```
- |The context is related to the question try to answer the `goal` as best as you can
- |or provide information about the found content
- |```goal
- |${prompt}
- |```
- |ANSWER:
- |"""
- .trimMargin()
- } else prompt
-}
-
-private suspend fun AIScope.callCompletionEndpoint(
- prompt: String,
- model: LLMModel,
- user: String = "testing",
- echo: Boolean = false,
- n: Int = 1,
- temperature: Double = 0.0,
- bringFromContext: Int,
- minResponseTokens: Int
-): List<String> {
- val promptWithContext: String =
- promptWithContext(prompt, bringFromContext, model.modelType, minResponseTokens)
-
- val maxTokens: Int = checkTotalLeftTokens(model.modelType, "", promptWithContext)
-
- val request =
- CompletionRequest(
- model = model.name,
- user = user,
- prompt = promptWithContext,
- echo = echo,
- n = n,
- temperature = temperature,
- maxTokens = maxTokens
- )
- return openAIClient.createCompletion(request).choices.map { it.text }
-}
-
-private suspend fun AIScope.callChatEndpoint(
- prompt: String,
- model: LLMModel,
- user: String = "testing",
- n: Int = 1,
- temperature: Double = 0.0,
- bringFromContext: Int,
- minResponseTokens: Int
-): List<String> {
- val role: String = Role.system.name
- val promptWithContext: String =
- promptWithContext(prompt, bringFromContext, model.modelType, minResponseTokens)
- val messages: List<Message> = listOf(Message(role, promptWithContext))
- val maxTokens: Int = checkTotalLeftChatTokens(messages, model)
- val request =
- ChatCompletionRequest(
- model = model.name,
- user = user,
- messages = messages,
- n = n,
- temperature = temperature,
- maxTokens = maxTokens
- )
- return openAIClient.createChatCompletion(request).choices.map { it.message.content }
-}
-
-private suspend fun AIScope.callChatEndpointWithFunctionsSupport(
- prompt: String,
- model: LLMModel,
- functions: List<CFunction>,
- user: String = "function",
- n: Int = 1,
- temperature: Double = 0.0,
- bringFromContext: Int,
- minResponseTokens: Int
-): List<FunctionCall> {
- val role: String = Role.user.name
- val firstFnName: String? = functions.firstOrNull()?.name
- val promptWithContext: String =
- promptWithContext(prompt, bringFromContext, model.modelType, minResponseTokens)
- val messages: List<Message> = listOf(Message(role, promptWithContext))
- val maxTokens: Int = checkTotalLeftChatTokens(messages, model)
- val request =
- ChatCompletionRequestWithFunctions(
- model = model.name,
- user = user,
- messages = messages,
- n = n,
- temperature = temperature,
- maxTokens = maxTokens,
- functions = functions,
- functionCall = mapOf("name" to (firstFnName ?: ""))
- )
- return openAIClient.createChatCompletionWithFunctions(request).choices.map {
- it.message.functionCall
- }
-}
-
-private suspend fun AIScope.promptWithContext(
- prompt: String,
- bringFromContext: Int,
- modelType: ModelType,
- minResponseTokens: Int
-): String {
- val ctxInfo: List<String> = context.similaritySearch(prompt, bringFromContext)
- return createPromptWithContextAwareOfTokens(
- ctxInfo = ctxInfo,
- modelType = modelType,
- prompt = prompt,
- minResponseTokens = minResponseTokens
- )
-}
-
-private fun checkTotalLeftTokens(
- modelType: ModelType,
- role: String,
- promptWithContext: String
-): Int =
- with(modelType) {
- val roleTokens: Int = encoding.countTokens(role)
- val padding = 20 // reserve 20 tokens for additional symbols around the context
- val promptTokens: Int = encoding.countTokens(promptWithContext)
- val takenTokens: Int = roleTokens + promptTokens + padding
- val totalLeftTokens: Int = maxContextLength - takenTokens
- if (totalLeftTokens < 0) {
- throw AIError.PromptExceedsMaxTokenLength(promptWithContext, takenTokens, maxContextLength)
- }
- logger.debug {
- "Tokens -- used: $takenTokens, model max: $maxContextLength, left: $totalLeftTokens"
- }
- totalLeftTokens
- }
-
-private fun AIScope.checkTotalLeftChatTokens(messages: List<Message>, model: LLMModel): Int {
- val maxContextLength: Int = model.modelType.maxContextLength
- val messagesTokens: Int = tokensFromMessages(messages, model)
- val totalLeftTokens: Int = maxContextLength - messagesTokens
- if (totalLeftTokens < 0) {
- throw AIError.MessagesExceedMaxTokenLength(messages, messagesTokens, maxContextLength)
- }
- logger.debug {
- "Tokens -- used: $messagesTokens, model max: $maxContextLength, left: $totalLeftTokens"
- }
- return totalLeftTokens
-}
-
-private fun tokensFromMessages(messages: List<Message>, model: LLMModel): Int =
- when (model) {
- LLMModel.GPT_3_5_TURBO_FUNCTIONS -> {
- val paddingTokens = 200
- // TODO 200 tokens reserved for function calls, what is a better way to count these?
- val fallbackModel: LLMModel = LLMModel.GPT_3_5_TURBO
- logger.debug {
- "Warning: ${model.name} may change over time. " +
- "Returning messages num tokens assuming ${fallbackModel.name} + $paddingTokens padding tokens."
- }
- tokensFromMessages(messages, fallbackModel) + paddingTokens
- }
- LLMModel.GPT_3_5_TURBO -> {
- val paddingTokens = 5 // otherwise if the model changes, it might later fail
- val fallbackModel: LLMModel = LLMModel.GPT_3_5_TURBO_0301
- logger.debug {
- "Warning: ${model.name} may change over time. " +
- "Returning messages num tokens assuming ${fallbackModel.name} + $paddingTokens padding tokens."
- }
- tokensFromMessages(messages, fallbackModel) + paddingTokens
- }
- LLMModel.GPT_4,
- LLMModel.GPT_4_32K -> {
- val paddingTokens = 5 // otherwise if the model changes, it might later fail
- val fallbackModel: LLMModel = LLMModel.GPT_4_0314
- logger.debug {
- "Warning: ${model.name} may change over time. " +
- "Returning messages num tokens assuming ${fallbackModel.name} + $paddingTokens padding tokens."
- }
- tokensFromMessages(messages, fallbackModel) + paddingTokens
- }
- LLMModel.GPT_3_5_TURBO_0301 ->
- model.modelType.encoding.countTokensFromMessages(
- messages,
- tokensPerMessage = 4,
- tokensPerName = 0
- )
- LLMModel.GPT_4_0314 ->
- model.modelType.encoding.countTokensFromMessages(
- messages,
- tokensPerMessage = 3,
- tokensPerName = 2
- )
- else -> {
- val paddingTokens = 20
- val fallbackModel: LLMModel = LLMModel.GPT_3_5_TURBO_0301
- logger.debug {
- "Warning: calculation of tokens is partially supported for ${model.name} . " +
- "Returning messages num tokens assuming ${fallbackModel.name} + $paddingTokens padding tokens."
- }
- tokensFromMessages(messages, fallbackModel) + paddingTokens
- }
- }
-
-private fun Encoding.countTokensFromMessages(
- messages: List<Message>,
- tokensPerMessage: Int,
- tokensPerName: Int
-): Int =
- messages.sumOf { message ->
- countTokens(message.role) +
- countTokens(message.content) +
- tokensPerMessage +
- (message.name?.let { tokensPerName } ?: 0)
- } + 3
diff --git a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/sql/DatabaseExample.kt b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/sql/DatabaseExample.kt
index 1aab0d810..223aa46d1 100644
--- a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/sql/DatabaseExample.kt
+++ b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/sql/DatabaseExample.kt
@@ -4,7 +4,6 @@ import arrow.core.raise.catch
import com.xebia.functional.tokenizer.ModelType
import com.xebia.functional.xef.auto.ai
import com.xebia.functional.xef.auto.getOrThrow
-import com.xebia.functional.xef.auto.promptMessage
import com.xebia.functional.xef.sql.SQL
import com.xebia.functional.xef.sql.jdbc.JdbcConfig
diff --git a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/tot/Search.kt b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/tot/Search.kt
index af70dbfba..6ecaa5d0e 100644
--- a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/tot/Search.kt
+++ b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/tot/Search.kt
@@ -1,7 +1,6 @@
package com.xebia.functional.xef.auto.tot
import com.xebia.functional.xef.auto.AIScope
-import com.xebia.functional.xef.auto.promptMessage
suspend fun AIScope.generateSearchPrompts(problem: Problem): List<String> =
promptMessage(
diff --git a/integrations/sql/src/main/kotlin/com/xebia/functional/xef/sql/SQL.kt b/integrations/sql/src/main/kotlin/com/xebia/functional/xef/sql/SQL.kt
index fe7e4f07b..420d97353 100644
--- a/integrations/sql/src/main/kotlin/com/xebia/functional/xef/sql/SQL.kt
+++ b/integrations/sql/src/main/kotlin/com/xebia/functional/xef/sql/SQL.kt
@@ -2,7 +2,6 @@ package com.xebia.functional.xef.sql
import com.xebia.functional.xef.auto.AIScope
import com.xebia.functional.xef.auto.AiDsl
-import com.xebia.functional.xef.auto.promptMessage
import com.xebia.functional.xef.sql.jdbc.JdbcConfig
import com.xebia.functional.xef.textsplitters.TokenTextSplitter
import io.github.oshai.kotlinlogging.KotlinLogging
diff --git a/scala/src/main/scala/com/xebia/functional/xef/scala/auto/package.scala b/scala/src/main/scala/com/xebia/functional/xef/scala/auto/package.scala
index 06715b43d..25bee1e68 100644
--- a/scala/src/main/scala/com/xebia/functional/xef/scala/auto/package.scala
+++ b/scala/src/main/scala/com/xebia/functional/xef/scala/auto/package.scala
@@ -6,14 +6,13 @@ import com.xebia.functional.xef.llm.openai.LLMModel
import com.xebia.functional.xef.llm.openai.functions.CFunction
import io.circe.Decoder
import io.circe.parser.parse
-import com.xebia.functional.xef.llm.openai.images.ImagesGenerationResponse
-import com.xebia.functional.xef.auto.{AIKt, Agent as KtAgent}
+import com.xebia.functional.xef.auto.AIKt
import com.xebia.functional.xef.auto.serialization.functions.FunctionSchemaKt
import com.xebia.functional.xef.pdf.PDFLoaderKt
import com.xebia.functional.tokenizer.ModelType
import com.xebia.functional.xef.llm.openai._
import com.xebia.functional.xef.scala.textsplitters.TextSplitter
-import scala.jdk.CollectionConverters._
+import com.xebia.functional.xef.llm.openai.images.*
import java.io.File
import scala.jdk.CollectionConverters.*
@@ -51,8 +50,7 @@ def prompt[A: Decoder: SerialDescriptor](
minResponseTokens: Int = 500
)(using scope: AIScope): A =
LoomAdapter.apply((cont) =>
- KtAgent.promptWithSerializer[A](
- scope.kt,
+ scope.kt.promptWithSerializer[A](
prompt,
FunctionSchemaKt.encodeFunctionSchema(SerialDescriptor[A].serialDescriptor),
(json: String) => parse(json).flatMap(Decoder[A].decodeJson(_)).fold(throw _, identity),
@@ -84,7 +82,7 @@ def promptMessage(
)(using scope: AIScope): List[String] =
LoomAdapter
.apply[java.util.List[String]](
- KtAgent.promptMessage(scope.kt, prompt, llmModel, functions.asJava, user, echo, n, temperature, bringFromContext, minResponseTokens, _)
+ scope.kt.promptMessage(prompt, llmModel, functions.asJava, user, echo, n, temperature, bringFromContext, minResponseTokens, _)
).asScala.toList
def pdf(
@@ -112,8 +110,7 @@ def images(
)(using scope: AIScope): List[String] =
LoomAdapter
.apply[ImagesGenerationResponse](cont =>
- KtAgent.images(
- scope.kt,
+ scope.kt.images(
prompt,
user,
n,