diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/AIError.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/AIError.kt
index 373ed3cf0..553b69a5c 100644
--- a/core/src/commonMain/kotlin/com/xebia/functional/xef/AIError.kt
+++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/AIError.kt
@@ -29,11 +29,7 @@ sealed interface AIError {
"Prompt exceeds max token length: $promptTokens + $maxTokens = ${promptTokens + maxTokens}"
}
- data class JsonParsing(
- val result: String,
- val maxAttempts: Int,
- val cause: IllegalArgumentException
- ) : AIError {
+ data class JsonParsing(val result: String, val maxAttempts: Int, val cause: Throwable) : AIError {
override val reason: String
get() = "Failed to parse the JSON response after $maxAttempts attempts: $result"
}
diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/DeserializerLLMAgent.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/DeserializerLLMAgent.kt
index f71687f52..71e34c0c5 100644
--- a/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/DeserializerLLMAgent.kt
+++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/DeserializerLLMAgent.kt
@@ -1,5 +1,9 @@
+@file:JvmMultifileClass
+@file:JvmName("Agent")
+
package com.xebia.functional.xef.auto
+import arrow.core.nonFatalOrThrow
import arrow.core.raise.catch
import arrow.core.raise.ensureNotNull
import com.xebia.functional.xef.AIError
@@ -7,8 +11,11 @@ import com.xebia.functional.xef.auto.serialization.buildJsonSchema
import com.xebia.functional.xef.llm.openai.LLMModel
import com.xebia.functional.xef.prompt.Prompt
import com.xebia.functional.xef.prompt.append
+import kotlin.jvm.JvmMultifileClass
+import kotlin.jvm.JvmName
import kotlinx.serialization.KSerializer
import kotlinx.serialization.SerializationException
+import kotlinx.serialization.descriptors.SerialDescriptor
import kotlinx.serialization.json.Json
import kotlinx.serialization.serializer
@@ -99,14 +106,37 @@ suspend fun AIScope.prompt(
temperature: Double = 0.0,
bringFromContext: Int = 10,
minResponseTokens: Int = 500,
-): A {
- val serializationConfig: SerializationConfig<A> =
- SerializationConfig(
- jsonSchema = buildJsonSchema(serializer.descriptor, false),
- descriptor = serializer.descriptor,
- deserializationStrategy = serializer
- )
+): A =
+ prompt(
+ prompt,
+ serializer.descriptor,
+ { json.decodeFromString(serializer, it) },
+ maxDeserializationAttempts,
+ model,
+ user,
+ echo,
+ n,
+ temperature,
+ bringFromContext,
+ minResponseTokens
+ )
+@AiDsl
+@JvmName("promptWithSerializer")
+suspend fun <A> AIScope.prompt(
+ prompt: Prompt,
+ descriptor: SerialDescriptor,
+ serializer: (json: String) -> A,
+ maxDeserializationAttempts: Int = 5,
+ model: LLMModel = LLMModel.GPT_3_5_TURBO,
+ user: String = "testing",
+ echo: Boolean = false,
+ n: Int = 1,
+ temperature: Double = 0.0,
+ bringFromContext: Int = 10,
+ minResponseTokens: Int = 500,
+): A {
+ val jsonSchema = buildJsonSchema(descriptor, false)
val responseInstructions =
"""
|
@@ -116,12 +146,12 @@ suspend fun AIScope.prompt(
|3. Use the JSON schema to produce the result exclusively in valid JSON format.
|4. Pay attention to required vs non-required fields in the schema.
|JSON Schema:
- |${serializationConfig.jsonSchema}
+ |$jsonSchema
|Response:
"""
.trimMargin()
- return tryDeserialize(serializationConfig, json, maxDeserializationAttempts) {
+ return tryDeserialize(serializer, maxDeserializationAttempts) {
promptMessage(
prompt.append(responseInstructions),
model,
@@ -136,8 +166,7 @@ suspend fun AIScope.prompt(
}
suspend fun <A> AIScope.tryDeserialize(
- serializationConfig: SerializationConfig,
- json: Json,
+ serializer: (json: String) -> A,
maxDeserializationAttempts: Int,
agent: AI<List<String>>
): A {
@@ -146,13 +175,10 @@ suspend fun AIScope.tryDeserialize(
currentAttempts++
val result = ensureNotNull(agent().firstOrNull()) { AIError.NoResponse }
catch({
- return@tryDeserialize json.decodeFromString(
- serializationConfig.deserializationStrategy,
- result
- )
- }) { e: IllegalArgumentException ->
+ return@tryDeserialize serializer(result)
+ }) { e: Throwable ->
if (currentAttempts == maxDeserializationAttempts)
- raise(AIError.JsonParsing(result, maxDeserializationAttempts, e))
+ raise(AIError.JsonParsing(result, maxDeserializationAttempts, e.nonFatalOrThrow()))
// else continue with the next attempt
}
}
diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/ImageGenerationAgent.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/ImageGenerationAgent.kt
index df641b6fe..c6e6d6394 100644
--- a/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/ImageGenerationAgent.kt
+++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/ImageGenerationAgent.kt
@@ -1,9 +1,14 @@
+@file:JvmMultifileClass
+@file:JvmName("Agent")
+
package com.xebia.functional.xef.auto
import com.xebia.functional.xef.AIError
import com.xebia.functional.xef.llm.openai.ImagesGenerationRequest
import com.xebia.functional.xef.llm.openai.ImagesGenerationResponse
import com.xebia.functional.xef.prompt.Prompt
+import kotlin.jvm.JvmMultifileClass
+import kotlin.jvm.JvmName
/**
* Run a [prompt] describes the images you want to generate within the context of [AIScope].
diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/LLMAgent.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/LLMAgent.kt
index d0284728c..e2d9f584b 100644
--- a/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/LLMAgent.kt
+++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/LLMAgent.kt
@@ -1,3 +1,6 @@
+@file:JvmMultifileClass
+@file:JvmName("Agent")
+
package com.xebia.functional.xef.auto
import arrow.core.raise.Raise
@@ -14,6 +17,8 @@ import com.xebia.functional.xef.llm.openai.Role
import com.xebia.functional.xef.prompt.Prompt
import io.github.oshai.KLogger
import io.github.oshai.KotlinLogging
+import kotlin.jvm.JvmMultifileClass
+import kotlin.jvm.JvmName
private val logger: KLogger by lazy { KotlinLogging.logger {} }
diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml
index d7023eb2d..5f5b2e0b9 100644
--- a/gradle/libs.versions.toml
+++ b/gradle/libs.versions.toml
@@ -82,6 +82,7 @@ doobie-postgres = { module = "org.tpolecat:doobie-postgres_3", version.ref = "do
doobie-hikari = { module = "org.tpolecat:doobie-hikari_3", version.ref = "doobie" }
doobie-munit = { module = "org.tpolecat:doobie-munit_3", version.ref = "doobie" }
circe = { module = "io.circe:circe-generic_3", version.ref = "circe" }
+circe-parser = { module = "io.circe:circe-parser_3", version.ref = "circe" }
cats-effect = { module = "org.typelevel:cats-effect_3", version.ref = "catsEffect" }
logger = { module = "org.typelevel:log4cats-slf4j_3", version.ref = "log4cats" }
openai = { module = "com.theokanning.openai-gpt3-java:service", version.ref = "openai" }
diff --git a/scala/build.gradle.kts b/scala/build.gradle.kts
index 9072d9ef3..0997f8017 100644
--- a/scala/build.gradle.kts
+++ b/scala/build.gradle.kts
@@ -16,6 +16,7 @@ dependencies {
implementation(libs.http4s.dsl)
implementation(libs.http4s.client)
implementation(libs.http4s.circe)
+ implementation(libs.circe.parser)
implementation(libs.http4s.emberClient)
implementation(libs.doobie.core)
implementation(libs.doobie.postgres)
diff --git a/scala/src/main/scala/com/xebia/functional/auto/AI.scala b/scala/src/main/scala/com/xebia/functional/auto/AI.scala
index 59a0fa29f..557cbd94e 100644
--- a/scala/src/main/scala/com/xebia/functional/auto/AI.scala
+++ b/scala/src/main/scala/com/xebia/functional/auto/AI.scala
@@ -1,17 +1,15 @@
package com.xebia.functional.auto
import com.xebia.functional.loom.LoomAdapter
+import com.xebia.functional.scala.auto.ScalaSerialDescriptor
import com.xebia.functional.xef.auto.AIScope as KtAIScope
+import com.xebia.functional.xef.auto.Agent as KtAgent
import com.xebia.functional.xef.auto.AIException
import com.xebia.functional.xef.auto.AIKt
import com.xebia.functional.xef.AIError
import com.xebia.functional.xef.llm.openai.LLMModel
-
-//def example(using AIScope): String =
-// prompt[String]("What is your name?")
-
-//val example: AIScope ?=> String =
-// prompt[String]("What is your name?")
+import io.circe.{Decoder, Json}
+import io.circe.parser.parse
object AI:
@@ -20,6 +18,7 @@ object AI:
AIKt.AIScope[A](
{ (coreAIScope, cont) =>
given AIScope = AIScope.fromCore(coreAIScope)
+
block
},
(e: AIError, cont) => throw AIException(e.getReason),
@@ -29,20 +28,35 @@ object AI:
end AI
-final case class AIScope(kt: KtAIScope):
-
- // TODO: Design signature for Scala3 w/ Json parser (with support for generating Json Schema)?
- def prompt[A](
- prompt: String,
- maxAttempts: Int = 5,
- llmMode: LLMModel = LLMModel.getGPT_3_5_TURBO
- ): A = ???
-
- def promptMessage(
- prompt: String,
- maxAttempts: Int = 5,
- llmMode: LLMModel = LLMModel.getGPT_3_5_TURBO
- ): String = ???
-
+final case class AIScope(kt: KtAIScope)
private object AIScope:
def fromCore(coreAIScope: KtAIScope): AIScope = new AIScope(coreAIScope)
+
+def prompt[A: Decoder: ScalaSerialDescriptor](
+ prompt: String,
+ maxAttempts: Int = 5,
+ llmMode: LLMModel = LLMModel.getGPT_3_5_TURBO,
+ user: String = "testing",
+ echo: Boolean = false,
+ n: Int = 1,
+ temperature: Double = 0.0,
+ bringFromContext: Int = 10,
+ minResponseTokens: Int = 500
+)(using scope: AIScope): A =
+ LoomAdapter.apply((cont) =>
+ KtAgent.promptWithSerializer[A](
+ scope.kt,
+ prompt,
+ ScalaSerialDescriptor[A].serialDescriptor,
+ (json) => parse(json).flatMap(Decoder[A].decodeJson(_)).fold(throw _, identity),
+ maxAttempts,
+ llmMode,
+ user,
+ echo,
+ n,
+ temperature,
+ bringFromContext,
+ minResponseTokens,
+ cont
+ )
+ )