Scala Json Parsing (#100)
Co-authored-by: raulraja <raulraja@gmail.com>
Co-authored-by: Yago Cervantes <1420230+Yawolf@users.noreply.github.com>
3 people authored May 24, 2023
1 parent eb5bf4f commit e35779e
Showing 7 changed files with 91 additions and 43 deletions.
@@ -29,11 +29,7 @@ sealed interface AIError {
"Prompt exceeds max token length: $promptTokens + $maxTokens = ${promptTokens + maxTokens}"
}

data class JsonParsing(
val result: String,
val maxAttempts: Int,
val cause: IllegalArgumentException
) : AIError {
data class JsonParsing(val result: String, val maxAttempts: Int, val cause: Throwable) : AIError {
override val reason: String
get() = "Failed to parse the JSON response after $maxAttempts attempts: $result"
}
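Widening the cause field from IllegalArgumentException to Throwable lets JsonParsing carry failures from decoders other than kotlinx.serialization. A minimal sketch of why this matters, assuming circe-parser on the classpath: circe's ParsingFailure is a Throwable but not an IllegalArgumentException, so the old field type could not hold it.

import io.circe.parser.parse

// circe failures extend io.circe.Error (an Exception), not
// IllegalArgumentException, so only a Throwable-typed cause can carry them.
val cause: Throwable =
  parse("{ not json").swap.getOrElse(sys.error("expected a parse failure"))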
@@ -1,14 +1,21 @@
@file:JvmMultifileClass
@file:JvmName("Agent")

package com.xebia.functional.xef.auto

import arrow.core.nonFatalOrThrow
import arrow.core.raise.catch
import arrow.core.raise.ensureNotNull
import com.xebia.functional.xef.AIError
import com.xebia.functional.xef.auto.serialization.buildJsonSchema
import com.xebia.functional.xef.llm.openai.LLMModel
import com.xebia.functional.xef.prompt.Prompt
import com.xebia.functional.xef.prompt.append
import kotlin.jvm.JvmMultifileClass
import kotlin.jvm.JvmName
import kotlinx.serialization.KSerializer
import kotlinx.serialization.SerializationException
import kotlinx.serialization.descriptors.SerialDescriptor
import kotlinx.serialization.json.Json
import kotlinx.serialization.serializer

@@ -99,14 +106,37 @@ suspend fun <A> AIScope.prompt(
temperature: Double = 0.0,
bringFromContext: Int = 10,
minResponseTokens: Int = 500,
): A {
val serializationConfig: SerializationConfig<A> =
SerializationConfig(
jsonSchema = buildJsonSchema(serializer.descriptor, false),
descriptor = serializer.descriptor,
deserializationStrategy = serializer
)
): A =
prompt(
prompt,
serializer.descriptor,
{ json.decodeFromString(serializer, it) },
maxDeserializationAttempts,
model,
user,
echo,
n,
temperature,
bringFromContext,
minResponseTokens
)

@AiDsl
@JvmName("promptWithSerializer")
suspend fun <A> AIScope.prompt(
prompt: Prompt,
descriptor: SerialDescriptor,
serializer: (json: String) -> A,
maxDeserializationAttempts: Int = 5,
model: LLMModel = LLMModel.GPT_3_5_TURBO,
user: String = "testing",
echo: Boolean = false,
n: Int = 1,
temperature: Double = 0.0,
bringFromContext: Int = 10,
minResponseTokens: Int = 500,
): A {
val jsonSchema = buildJsonSchema(descriptor, false)
val responseInstructions =
"""
|
@@ -116,12 +146,12 @@ suspend fun <A> AIScope.prompt(
|3. Use the JSON schema to produce the result exclusively in valid JSON format.
|4. Pay attention to required vs non-required fields in the schema.
|JSON Schema:
|${serializationConfig.jsonSchema}
|$jsonSchema
|Response:
"""
.trimMargin()

return tryDeserialize(serializationConfig, json, maxDeserializationAttempts) {
return tryDeserialize(serializer, maxDeserializationAttempts) {
promptMessage(
prompt.append(responseInstructions),
model,
@@ -136,8 +166,7 @@ suspend fun <A> AIScope.prompt(
}
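The new overload takes deserialization as a plain (json: String) -> A function instead of a kotlinx DeserializationStrategy, which is what lets the Scala module plug circe in (see AI.scala below). A minimal sketch of such a function, assuming a circe Decoder[A] in scope:

import io.circe.Decoder
import io.circe.parser.parse

// Parse the raw model output, then decode it; either failure is thrown,
// which the retry loop in tryDeserialize treats as a failed attempt.
def circeSerializer[A: Decoder](json: String): A =
  parse(json).flatMap(Decoder[A].decodeJson(_)).fold(throw _, identity)

Modulo naming, this is the function the Scala wrapper passes to promptWithSerializer in AI.scala below.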

suspend fun <A> AIScope.tryDeserialize(
serializationConfig: SerializationConfig<A>,
json: Json,
serializer: (json: String) -> A,
maxDeserializationAttempts: Int,
agent: AI<List<String>>
): A {
@@ -146,13 +175,10 @@ suspend fun <A> AIScope.tryDeserialize(
currentAttempts++
val result = ensureNotNull(agent().firstOrNull()) { AIError.NoResponse }
catch({
return@tryDeserialize json.decodeFromString(
serializationConfig.deserializationStrategy,
result
)
}) { e: IllegalArgumentException ->
return@tryDeserialize serializer(result)
}) { e: Throwable ->
if (currentAttempts == maxDeserializationAttempts)
raise(AIError.JsonParsing(result, maxDeserializationAttempts, e))
raise(AIError.JsonParsing(result, maxDeserializationAttempts, e.nonFatalOrThrow()))
// else continue with the next attempt
}
}
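For reference, the retry contract of tryDeserialize restated as a self-contained Scala sketch (not the Kotlin implementation above): run the agent up to maxAttempts times and return the first response the supplied decoder accepts.

import scala.annotation.tailrec
import scala.util.{Failure, Success, Try}

// Ask the agent for a response and try to decode it; on failure, retry
// until attempts run out, then rethrow with the offending payload attached.
@tailrec
def tryDeserialize[A](deserialize: String => A, maxAttempts: Int)(agent: () => List[String]): A =
  val result = agent().headOption.getOrElse(sys.error("no response"))
  Try(deserialize(result)) match
    case Success(a) => a
    case Failure(_) if maxAttempts > 1 => tryDeserialize(deserialize, maxAttempts - 1)(agent)
    case Failure(e) => throw new IllegalStateException(s"failed to parse: $result", e)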
@@ -1,9 +1,14 @@
@file:JvmMultifileClass
@file:JvmName("Agent")

package com.xebia.functional.xef.auto

import com.xebia.functional.xef.AIError
import com.xebia.functional.xef.llm.openai.ImagesGenerationRequest
import com.xebia.functional.xef.llm.openai.ImagesGenerationResponse
import com.xebia.functional.xef.prompt.Prompt
import kotlin.jvm.JvmMultifileClass
import kotlin.jvm.JvmName

/**
 * Run a [prompt] that describes the images you want to generate within the context of [AIScope].
@@ -1,3 +1,6 @@
@file:JvmMultifileClass
@file:JvmName("Agent")

package com.xebia.functional.xef.auto

import arrow.core.raise.Raise
@@ -14,6 +17,8 @@ import com.xebia.functional.xef.llm.openai.Role
import com.xebia.functional.xef.prompt.Prompt
import io.github.oshai.KLogger
import io.github.oshai.KotlinLogging
import kotlin.jvm.JvmMultifileClass
import kotlin.jvm.JvmName

private val logger: KLogger by lazy { KotlinLogging.logger {} }

1 change: 1 addition & 0 deletions gradle/libs.versions.toml
@@ -82,6 +82,7 @@ doobie-postgres = { module = "org.tpolecat:doobie-postgres_3", version.ref = "doobie" }
doobie-hikari = { module = "org.tpolecat:doobie-hikari_3", version.ref = "doobie" }
doobie-munit = { module = "org.tpolecat:doobie-munit_3", version.ref = "doobie" }
circe = { module = "io.circe:circe-generic_3", version.ref = "circe" }
circe-parser = { module = "io.circe:circe-parser_3", version.ref = "circe" }
cats-effect = { module = "org.typelevel:cats-effect_3", version.ref = "catsEffect" }
logger = { module = "org.typelevel:log4cats-slf4j_3", version.ref = "log4cats" }
openai = { module = "com.theokanning.openai-gpt3-java:service", version.ref = "openai" }
1 change: 1 addition & 0 deletions scala/build.gradle.kts
@@ -16,6 +16,7 @@ dependencies {
implementation(libs.http4s.dsl)
implementation(libs.http4s.client)
implementation(libs.http4s.circe)
implementation(libs.circe.parser)
implementation(libs.http4s.emberClient)
implementation(libs.doobie.core)
implementation(libs.doobie.postgres)
56 changes: 35 additions & 21 deletions scala/src/main/scala/com/xebia/functional/auto/AI.scala
@@ -1,17 +1,15 @@
package com.xebia.functional.auto

import com.xebia.functional.loom.LoomAdapter
import com.xebia.functional.scala.auto.ScalaSerialDescriptor
import com.xebia.functional.xef.auto.AIScope as KtAIScope
import com.xebia.functional.xef.auto.Agent as KtAgent
import com.xebia.functional.xef.auto.AIException
import com.xebia.functional.xef.auto.AIKt
import com.xebia.functional.xef.AIError
import com.xebia.functional.xef.llm.openai.LLMModel

//def example(using AIScope): String =
// prompt[String]("What is your name?")

//val example: AIScope ?=> String =
// prompt[String]("What is your name?")
import io.circe.{Decoder, Json}
import io.circe.parser.parse

object AI:

@@ -20,6 +18,7 @@ object AI:
AIKt.AIScope[A](
{ (coreAIScope, cont) =>
given AIScope = AIScope.fromCore(coreAIScope)

block
},
(e: AIError, cont) => throw AIException(e.getReason),
@@ -29,20 +28,35 @@

end AI

final case class AIScope(kt: KtAIScope):

// TODO: Design signature for Scala3 w/ Json parser (with support for generating Json Schema)?
def prompt[A](
prompt: String,
maxAttempts: Int = 5,
llmMode: LLMModel = LLMModel.getGPT_3_5_TURBO
): A = ???

def promptMessage(
prompt: String,
maxAttempts: Int = 5,
llmMode: LLMModel = LLMModel.getGPT_3_5_TURBO
): String = ???

final case class AIScope(kt: KtAIScope)
private object AIScope:
def fromCore(coreAIScope: KtAIScope): AIScope = new AIScope(coreAIScope)

def prompt[A: Decoder: ScalaSerialDescriptor](
prompt: String,
maxAttempts: Int = 5,
llmMode: LLMModel = LLMModel.getGPT_3_5_TURBO,
user: String = "testing",
echo: Boolean = false,
n: Int = 1,
temperature: Double = 0.0,
bringFromContext: Int = 10,
minResponseTokens: Int = 500
)(using scope: AIScope): A =
LoomAdapter.apply((cont) =>
KtAgent.promptWithSerializer[A](
scope.kt,
prompt,
ScalaSerialDescriptor[A].serialDescriptor,
(json) => parse(json).flatMap(Decoder[A].decodeJson(_)).fold(throw _, identity),
maxAttempts,
llmMode,
user,
echo,
n,
temperature,
bringFromContext,
minResponseTokens,
cont
)
)
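Putting it together, a hypothetical end-to-end sketch of the new Scala API. The Recipe type, its Decoder, the AI { ... } entry-point shape, and the @main wrapper are illustrative assumptions, and the ScalaSerialDescriptor instance is left as a placeholder for the project's own derivation:

import com.xebia.functional.auto.*
import com.xebia.functional.scala.auto.ScalaSerialDescriptor
import io.circe.Decoder

final case class Recipe(name: String, ingredients: List[String])

given Decoder[Recipe] =
  Decoder.forProduct2("name", "ingredients")(Recipe.apply)

// Placeholder only: substitute the project's ScalaSerialDescriptor
// derivation, which supplies the JSON schema sent to the model.
given ScalaSerialDescriptor[Recipe] = ???

@main def cookies(): Unit =
  val recipe: Recipe = AI {
    prompt[Recipe]("Give me a recipe for chocolate chip cookies")
  }
  println(recipe)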
