Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Scala Json Parsing #100

Merged
merged 10 commits into from
May 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,7 @@ sealed interface AIError {
"Prompt exceeds max token length: $promptTokens + $maxTokens = ${promptTokens + maxTokens}"
}

data class JsonParsing(
val result: String,
val maxAttempts: Int,
val cause: IllegalArgumentException
) : AIError {
data class JsonParsing(val result: String, val maxAttempts: Int, val cause: Throwable) : AIError {
override val reason: String
get() = "Failed to parse the JSON response after $maxAttempts attempts: $result"
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,21 @@
@file:JvmMultifileClass
@file:JvmName("Agent")

package com.xebia.functional.xef.auto

import arrow.core.nonFatalOrThrow
import arrow.core.raise.catch
import arrow.core.raise.ensureNotNull
import com.xebia.functional.xef.AIError
import com.xebia.functional.xef.auto.serialization.buildJsonSchema
import com.xebia.functional.xef.llm.openai.LLMModel
import com.xebia.functional.xef.prompt.Prompt
import com.xebia.functional.xef.prompt.append
import kotlin.jvm.JvmMultifileClass
import kotlin.jvm.JvmName
import kotlinx.serialization.KSerializer
import kotlinx.serialization.SerializationException
import kotlinx.serialization.descriptors.SerialDescriptor
import kotlinx.serialization.json.Json
import kotlinx.serialization.serializer

Expand Down Expand Up @@ -99,14 +106,37 @@ suspend fun <A> AIScope.prompt(
temperature: Double = 0.0,
bringFromContext: Int = 10,
minResponseTokens: Int = 500,
): A {
val serializationConfig: SerializationConfig<A> =
SerializationConfig(
jsonSchema = buildJsonSchema(serializer.descriptor, false),
descriptor = serializer.descriptor,
deserializationStrategy = serializer
)
): A =
prompt(
prompt,
serializer.descriptor,
{ json.decodeFromString(serializer, it) },
maxDeserializationAttempts,
model,
user,
echo,
n,
temperature,
bringFromContext,
minResponseTokens
)

@AiDsl
@JvmName("promptWithSerializer")
suspend fun <A> AIScope.prompt(
prompt: Prompt,
descriptor: SerialDescriptor,
serializer: (json: String) -> A,
maxDeserializationAttempts: Int = 5,
model: LLMModel = LLMModel.GPT_3_5_TURBO,
user: String = "testing",
echo: Boolean = false,
n: Int = 1,
temperature: Double = 0.0,
bringFromContext: Int = 10,
minResponseTokens: Int = 500,
): A {
val jsonSchema = buildJsonSchema(descriptor, false)
val responseInstructions =
"""
|
Expand All @@ -116,12 +146,12 @@ suspend fun <A> AIScope.prompt(
|3. Use the JSON schema to produce the result exclusively in valid JSON format.
|4. Pay attention to required vs non-required fields in the schema.
|JSON Schema:
|${serializationConfig.jsonSchema}
|$jsonSchema
|Response:
"""
.trimMargin()

return tryDeserialize(serializationConfig, json, maxDeserializationAttempts) {
return tryDeserialize(serializer, maxDeserializationAttempts) {
promptMessage(
prompt.append(responseInstructions),
model,
Expand All @@ -136,8 +166,7 @@ suspend fun <A> AIScope.prompt(
}

suspend fun <A> AIScope.tryDeserialize(
serializationConfig: SerializationConfig<A>,
json: Json,
serializer: (json: String) -> A,
maxDeserializationAttempts: Int,
agent: AI<List<String>>
): A {
Expand All @@ -146,13 +175,10 @@ suspend fun <A> AIScope.tryDeserialize(
currentAttempts++
val result = ensureNotNull(agent().firstOrNull()) { AIError.NoResponse }
catch({
return@tryDeserialize json.decodeFromString(
serializationConfig.deserializationStrategy,
result
)
}) { e: IllegalArgumentException ->
return@tryDeserialize serializer(result)
}) { e: Throwable ->
if (currentAttempts == maxDeserializationAttempts)
raise(AIError.JsonParsing(result, maxDeserializationAttempts, e))
raise(AIError.JsonParsing(result, maxDeserializationAttempts, e.nonFatalOrThrow()))
// else continue with the next attempt
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
@file:JvmMultifileClass
@file:JvmName("Agent")

package com.xebia.functional.xef.auto

import com.xebia.functional.xef.AIError
import com.xebia.functional.xef.llm.openai.ImagesGenerationRequest
import com.xebia.functional.xef.llm.openai.ImagesGenerationResponse
import com.xebia.functional.xef.prompt.Prompt
import kotlin.jvm.JvmMultifileClass
import kotlin.jvm.JvmName

/**
 * Runs a [prompt] that describes the images you want to generate within the context of [AIScope].
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
@file:JvmMultifileClass
@file:JvmName("Agent")

package com.xebia.functional.xef.auto

import arrow.core.raise.Raise
Expand All @@ -14,6 +17,8 @@ import com.xebia.functional.xef.llm.openai.Role
import com.xebia.functional.xef.prompt.Prompt
import io.github.oshai.KLogger
import io.github.oshai.KotlinLogging
import kotlin.jvm.JvmMultifileClass
import kotlin.jvm.JvmName

private val logger: KLogger by lazy { KotlinLogging.logger {} }

Expand Down
1 change: 1 addition & 0 deletions gradle/libs.versions.toml
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ doobie-postgres = { module = "org.tpolecat:doobie-postgres_3", version.ref = "do
doobie-hikari = { module = "org.tpolecat:doobie-hikari_3", version.ref = "doobie" }
doobie-munit = { module = "org.tpolecat:doobie-munit_3", version.ref = "doobie" }
circe = { module = "io.circe:circe-generic_3", version.ref = "circe" }
circe-parser = { module = "io.circe:circe-parser_3", version.ref = "circe" }
cats-effect = { module = "org.typelevel:cats-effect_3", version.ref = "catsEffect" }
logger = { module = "org.typelevel:log4cats-slf4j_3", version.ref = "log4cats" }
openai = { module = "com.theokanning.openai-gpt3-java:service", version.ref = "openai" }
Expand Down
1 change: 1 addition & 0 deletions scala/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ dependencies {
implementation(libs.http4s.dsl)
implementation(libs.http4s.client)
implementation(libs.http4s.circe)
implementation(libs.circe.parser)
implementation(libs.http4s.emberClient)
implementation(libs.doobie.core)
implementation(libs.doobie.postgres)
Expand Down
56 changes: 35 additions & 21 deletions scala/src/main/scala/com/xebia/functional/auto/AI.scala
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
package com.xebia.functional.auto

import com.xebia.functional.loom.LoomAdapter
import com.xebia.functional.scala.auto.ScalaSerialDescriptor
import com.xebia.functional.xef.auto.AIScope as KtAIScope
import com.xebia.functional.xef.auto.Agent as KtAgent
import com.xebia.functional.xef.auto.AIException
import com.xebia.functional.xef.auto.AIKt
import com.xebia.functional.xef.AIError
import com.xebia.functional.xef.llm.openai.LLMModel

//def example(using AIScope): String =
// prompt[String]("What is your name?")

//val example: AIScope ?=> String =
// prompt[String]("What is your name?")
import io.circe.{Decoder, Json}
import io.circe.parser.parse

object AI:

Expand All @@ -20,6 +18,7 @@ object AI:
AIKt.AIScope[A](
{ (coreAIScope, cont) =>
given AIScope = AIScope.fromCore(coreAIScope)

block
},
(e: AIError, cont) => throw AIException(e.getReason),
Expand All @@ -29,20 +28,35 @@ object AI:

end AI

final case class AIScope(kt: KtAIScope):

// TODO: Design signature for Scala3 w/ Json parser (with support for generating Json Schema)?
def prompt[A](
prompt: String,
maxAttempts: Int = 5,
llmMode: LLMModel = LLMModel.getGPT_3_5_TURBO
): A = ???

def promptMessage(
prompt: String,
maxAttempts: Int = 5,
llmMode: LLMModel = LLMModel.getGPT_3_5_TURBO
): String = ???

final case class AIScope(kt: KtAIScope)
private object AIScope:
def fromCore(coreAIScope: KtAIScope): AIScope = new AIScope(coreAIScope)

def prompt[A: Decoder: ScalaSerialDescriptor](
prompt: String,
maxAttempts: Int = 5,
llmMode: LLMModel = LLMModel.getGPT_3_5_TURBO,
user: String = "testing",
echo: Boolean = false,
n: Int = 1,
temperature: Double = 0.0,
bringFromContext: Int = 10,
minResponseTokens: Int = 500
)(using scope: AIScope): A =
LoomAdapter.apply((cont) =>
KtAgent.promptWithSerializer[A](
scope.kt,
prompt,
ScalaSerialDescriptor[A].serialDescriptor,
(json) => parse(json).flatMap(Decoder[A].decodeJson(_)).fold(throw _, identity),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm okay with this, but it failed in the previous build because io.circe.parse.decoder is needed 🙂

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe I misunderstood your suggestion, but that was failing compilation. Feel free to update this to whatever you feel is most idiomatic ☺️

Decoder.decodeJson takes a Json and not a String, so I used parse first.

maxAttempts,
llmMode,
user,
echo,
n,
temperature,
bringFromContext,
minResponseTokens,
cont
)
)