diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/AIError.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/AIError.kt index 373ed3cf0..553b69a5c 100644 --- a/core/src/commonMain/kotlin/com/xebia/functional/xef/AIError.kt +++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/AIError.kt @@ -29,11 +29,7 @@ sealed interface AIError { "Prompt exceeds max token length: $promptTokens + $maxTokens = ${promptTokens + maxTokens}" } - data class JsonParsing( - val result: String, - val maxAttempts: Int, - val cause: IllegalArgumentException - ) : AIError { + data class JsonParsing(val result: String, val maxAttempts: Int, val cause: Throwable) : AIError { override val reason: String get() = "Failed to parse the JSON response after $maxAttempts attempts: $result" } diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/DeserializerLLMAgent.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/DeserializerLLMAgent.kt index f71687f52..71e34c0c5 100644 --- a/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/DeserializerLLMAgent.kt +++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/DeserializerLLMAgent.kt @@ -1,5 +1,9 @@ +@file:JvmMultifileClass +@file:JvmName("Agent") + package com.xebia.functional.xef.auto +import arrow.core.nonFatalOrThrow import arrow.core.raise.catch import arrow.core.raise.ensureNotNull import com.xebia.functional.xef.AIError @@ -7,8 +11,11 @@ import com.xebia.functional.xef.auto.serialization.buildJsonSchema import com.xebia.functional.xef.llm.openai.LLMModel import com.xebia.functional.xef.prompt.Prompt import com.xebia.functional.xef.prompt.append +import kotlin.jvm.JvmMultifileClass +import kotlin.jvm.JvmName import kotlinx.serialization.KSerializer import kotlinx.serialization.SerializationException +import kotlinx.serialization.descriptors.SerialDescriptor import kotlinx.serialization.json.Json import kotlinx.serialization.serializer @@ -99,14 +106,37 @@ suspend fun AIScope.prompt( temperature: Double = 0.0, bringFromContext: Int = 10, minResponseTokens: Int = 500, -): A { - val serializationConfig: SerializationConfig = - SerializationConfig( - jsonSchema = buildJsonSchema(serializer.descriptor, false), - descriptor = serializer.descriptor, - deserializationStrategy = serializer - ) +): A = + prompt( + prompt, + serializer.descriptor, + { json.decodeFromString(serializer, it) }, + maxDeserializationAttempts, + model, + user, + echo, + n, + temperature, + bringFromContext, + minResponseTokens + ) +@AiDsl +@JvmName("promptWithSerializer") +suspend fun AIScope.prompt( + prompt: Prompt, + descriptor: SerialDescriptor, + serializer: (json: String) -> A, + maxDeserializationAttempts: Int = 5, + model: LLMModel = LLMModel.GPT_3_5_TURBO, + user: String = "testing", + echo: Boolean = false, + n: Int = 1, + temperature: Double = 0.0, + bringFromContext: Int = 10, + minResponseTokens: Int = 500, +): A { + val jsonSchema = buildJsonSchema(descriptor, false) val responseInstructions = """ | @@ -116,12 +146,12 @@ suspend fun AIScope.prompt( |3. Use the JSON schema to produce the result exclusively in valid JSON format. |4. Pay attention to required vs non-required fields in the schema. |JSON Schema: - |${serializationConfig.jsonSchema} + |$jsonSchema |Response: """ .trimMargin() - return tryDeserialize(serializationConfig, json, maxDeserializationAttempts) { + return tryDeserialize(serializer, maxDeserializationAttempts) { promptMessage( prompt.append(responseInstructions), model, @@ -136,8 +166,7 @@ suspend fun AIScope.prompt( } suspend fun AIScope.tryDeserialize( - serializationConfig: SerializationConfig, - json: Json, + serializer: (json: String) -> A, maxDeserializationAttempts: Int, agent: AI> ): A { @@ -146,13 +175,10 @@ suspend fun AIScope.tryDeserialize( currentAttempts++ val result = ensureNotNull(agent().firstOrNull()) { AIError.NoResponse } catch({ - return@tryDeserialize json.decodeFromString( - serializationConfig.deserializationStrategy, - result - ) - }) { e: IllegalArgumentException -> + return@tryDeserialize serializer(result) + }) { e: Throwable -> if (currentAttempts == maxDeserializationAttempts) - raise(AIError.JsonParsing(result, maxDeserializationAttempts, e)) + raise(AIError.JsonParsing(result, maxDeserializationAttempts, e.nonFatalOrThrow())) // else continue with the next attempt } } diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/ImageGenerationAgent.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/ImageGenerationAgent.kt index df641b6fe..c6e6d6394 100644 --- a/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/ImageGenerationAgent.kt +++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/ImageGenerationAgent.kt @@ -1,9 +1,14 @@ +@file:JvmMultifileClass +@file:JvmName("Agent") + package com.xebia.functional.xef.auto import com.xebia.functional.xef.AIError import com.xebia.functional.xef.llm.openai.ImagesGenerationRequest import com.xebia.functional.xef.llm.openai.ImagesGenerationResponse import com.xebia.functional.xef.prompt.Prompt +import kotlin.jvm.JvmMultifileClass +import kotlin.jvm.JvmName /** * Run a [prompt] describes the images you want to generate within the context of [AIScope]. diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/LLMAgent.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/LLMAgent.kt index d0284728c..e2d9f584b 100644 --- a/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/LLMAgent.kt +++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/LLMAgent.kt @@ -1,3 +1,6 @@ +@file:JvmMultifileClass +@file:JvmName("Agent") + package com.xebia.functional.xef.auto import arrow.core.raise.Raise @@ -14,6 +17,8 @@ import com.xebia.functional.xef.llm.openai.Role import com.xebia.functional.xef.prompt.Prompt import io.github.oshai.KLogger import io.github.oshai.KotlinLogging +import kotlin.jvm.JvmMultifileClass +import kotlin.jvm.JvmName private val logger: KLogger by lazy { KotlinLogging.logger {} } diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index d7023eb2d..5f5b2e0b9 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -82,6 +82,7 @@ doobie-postgres = { module = "org.tpolecat:doobie-postgres_3", version.ref = "do doobie-hikari = { module = "org.tpolecat:doobie-hikari_3", version.ref = "doobie" } doobie-munit = { module = "org.tpolecat:doobie-munit_3", version.ref = "doobie" } circe = { module = "io.circe:circe-generic_3", version.ref = "circe" } +circe-parser = { module = "io.circe:circe-parser_3", version.ref = "circe" } cats-effect = { module = "org.typelevel:cats-effect_3", version.ref = "catsEffect" } logger = { module = "org.typelevel:log4cats-slf4j_3", version.ref = "log4cats" } openai = { module = "com.theokanning.openai-gpt3-java:service", version.ref = "openai" } diff --git a/scala/build.gradle.kts b/scala/build.gradle.kts index 9072d9ef3..0997f8017 100644 --- a/scala/build.gradle.kts +++ b/scala/build.gradle.kts @@ -16,6 +16,7 @@ dependencies { implementation(libs.http4s.dsl) implementation(libs.http4s.client) implementation(libs.http4s.circe) + implementation(libs.circe.parser) implementation(libs.http4s.emberClient) implementation(libs.doobie.core) implementation(libs.doobie.postgres) diff --git a/scala/src/main/scala/com/xebia/functional/auto/AI.scala b/scala/src/main/scala/com/xebia/functional/auto/AI.scala index 59a0fa29f..557cbd94e 100644 --- a/scala/src/main/scala/com/xebia/functional/auto/AI.scala +++ b/scala/src/main/scala/com/xebia/functional/auto/AI.scala @@ -1,17 +1,15 @@ package com.xebia.functional.auto import com.xebia.functional.loom.LoomAdapter +import com.xebia.functional.scala.auto.ScalaSerialDescriptor import com.xebia.functional.xef.auto.AIScope as KtAIScope +import com.xebia.functional.xef.auto.Agent as KtAgent import com.xebia.functional.xef.auto.AIException import com.xebia.functional.xef.auto.AIKt import com.xebia.functional.xef.AIError import com.xebia.functional.xef.llm.openai.LLMModel - -//def example(using AIScope): String = -// prompt[String]("What is your name?") - -//val example: AIScope ?=> String = -// prompt[String]("What is your name?") +import io.circe.{Decoder, Json} +import io.circe.parser.parse object AI: @@ -20,6 +18,7 @@ object AI: AIKt.AIScope[A]( { (coreAIScope, cont) => given AIScope = AIScope.fromCore(coreAIScope) + block }, (e: AIError, cont) => throw AIException(e.getReason), @@ -29,20 +28,35 @@ object AI: end AI -final case class AIScope(kt: KtAIScope): - - // TODO: Design signature for Scala3 w/ Json parser (with support for generating Json Schema)? - def prompt[A]( - prompt: String, - maxAttempts: Int = 5, - llmMode: LLMModel = LLMModel.getGPT_3_5_TURBO - ): A = ??? - - def promptMessage( - prompt: String, - maxAttempts: Int = 5, - llmMode: LLMModel = LLMModel.getGPT_3_5_TURBO - ): String = ??? - +final case class AIScope(kt: KtAIScope) private object AIScope: def fromCore(coreAIScope: KtAIScope): AIScope = new AIScope(coreAIScope) + +def prompt[A: Decoder: ScalaSerialDescriptor]( + prompt: String, + maxAttempts: Int = 5, + llmMode: LLMModel = LLMModel.getGPT_3_5_TURBO, + user: String = "testing", + echo: Boolean = false, + n: Int = 1, + temperature: Double = 0.0, + bringFromContext: Int = 10, + minResponseTokens: Int = 500 +)(using scope: AIScope): A = + LoomAdapter.apply((cont) => + KtAgent.promptWithSerializer[A]( + scope.kt, + prompt, + ScalaSerialDescriptor[A].serialDescriptor, + (json) => parse(json).flatMap(Decoder[A].decodeJson(_)).fold(throw _, identity), + maxAttempts, + llmMode, + user, + echo, + n, + temperature, + bringFromContext, + minResponseTokens, + cont + ) + )