Replace Raise with exception hierarchy (#171)
nomisRev authored Jun 8, 2023
1 parent 9cd1023 commit 4fcfca5
Showing 13 changed files with 79 additions and 169 deletions.
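
In short: AIError becomes a sealed class extending RuntimeException, and the Raise-based plumbing (raise, recover, ensure) is replaced with plain throw / try-catch at every call site. A minimal before/after sketch of the caller-facing change, assuming only the types visible in the diffs below (the Raise version is paraphrased, not quoted from the old code):

import com.xebia.functional.xef.AIError

// Before (paraphrased): errors travelled through Arrow's Raise context, e.g.
//   fun Raise<AIError>.firstChoice(xs: List<String>): String =
//     xs.firstOrNull() ?: raise(AIError.NoResponse)

// After: AIError extends RuntimeException, so ordinary control flow applies.
fun firstChoice(xs: List<String>): String =
  xs.firstOrNull() ?: throw AIError.NoResponse()

fun main() {
  try {
    println(firstChoice(emptyList()))
  } catch (e: AIError) {
    println("Recovered: ${e.message}") // prints "Recovered: No response from the AI"
  }
}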
43 changes: 18 additions & 25 deletions core/src/commonMain/kotlin/com/xebia/functional/xef/AIError.kt
@@ -2,47 +2,40 @@ package com.xebia.functional.xef

 import arrow.core.NonEmptyList
 import com.xebia.functional.xef.llm.openai.Message
+import kotlin.jvm.JvmOverloads

-sealed interface AIError {
-  val reason: String
+sealed class AIError @JvmOverloads constructor(message: String, cause: Throwable? = null) :
+  RuntimeException(message, cause) {

-  object NoResponse : AIError {
-    override val reason: String
-      get() = "No response from the AI"
-  }
+  class NoResponse : AIError("No response from the AI")

   data class MessagesExceedMaxTokenLength(
     val messages: List<Message>,
     val promptTokens: Int,
     val maxTokens: Int
-  ) : AIError {
-    override val reason: String =
+  ) :
+    AIError(
       "Prompt exceeds max token length: $promptTokens + $maxTokens = ${promptTokens + maxTokens}"
-  }
+    )

   data class PromptExceedsMaxTokenLength(
     val prompt: String,
     val promptTokens: Int,
     val maxTokens: Int
-  ) : AIError {
-    override val reason: String =
+  ) :
+    AIError(
       "Prompt exceeds max token length: $promptTokens + $maxTokens = ${promptTokens + maxTokens}"
-  }
+    )

-  data class JsonParsing(val result: String, val maxAttempts: Int, val cause: Throwable) : AIError {
-    override val reason: String
-      get() = "Failed to parse the JSON response after $maxAttempts attempts: $result"
-  }
+  data class JsonParsing(val result: String, val maxAttempts: Int, override val cause: Throwable) :
+    AIError("Failed to parse the JSON response after $maxAttempts attempts: $result", cause)

-  sealed interface Env : AIError {
-    data class OpenAI(val errors: NonEmptyList<String>) : Env {
-      override val reason: String
-        get() = "OpenAI Environment not found: ${errors.all.joinToString("\n")}"
-    }
+  sealed class Env @JvmOverloads constructor(message: String, cause: Throwable? = null) :
+    AIError(message, cause) {
+    data class OpenAI(val errors: NonEmptyList<String>) :
+      Env("OpenAI Environment not found: ${errors.all.joinToString("\n")}")

-    data class HuggingFace(val errors: NonEmptyList<String>) : Env {
-      override val reason: String
-        get() = "HuggingFace Environment not found: ${errors.all.joinToString("\n")}"
-    }
+    data class HuggingFace(val errors: NonEmptyList<String>) :
+      Env("HuggingFace Environment not found: ${errors.all.joinToString("\n")}")
   }
 }
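
For downstream code the practical difference is that the hierarchy above can be caught and matched like any other exception, with the standard message property replacing the removed reason. A sketch of exhaustive matching over the new subtypes:

import com.xebia.functional.xef.AIError

// Sketch: sealed-class matching over the hierarchy defined above.
fun describe(e: AIError): String =
  when (e) {
    is AIError.NoResponse -> "no response from the model"
    is AIError.MessagesExceedMaxTokenLength -> "chat too long: ${e.promptTokens} prompt tokens"
    is AIError.PromptExceedsMaxTokenLength -> "prompt too long: ${e.promptTokens} prompt tokens"
    is AIError.JsonParsing -> "unparseable JSON after ${e.maxAttempts} attempts"
    is AIError.Env -> "environment problem: ${e.message}"
  }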
31 changes: 8 additions & 23 deletions core/src/commonMain/kotlin/com/xebia/functional/xef/auto/AI.kt
@@ -3,7 +3,6 @@ package com.xebia.functional.xef.auto
 import arrow.core.Either
 import arrow.core.left
 import arrow.core.raise.Raise
-import arrow.core.raise.recover
 import arrow.core.right
 import arrow.fx.coroutines.ResourceScope
 import arrow.fx.coroutines.resourceScope
@@ -21,18 +20,9 @@ import io.github.oshai.kotlinlogging.KLogger
 import io.github.oshai.kotlinlogging.KotlinLogging
 import kotlin.jvm.JvmName
 import kotlin.time.ExperimentalTime
-import kotlinx.serialization.DeserializationStrategy
-import kotlinx.serialization.descriptors.SerialDescriptor
-import kotlinx.serialization.json.JsonObject

 @DslMarker annotation class AiDsl

-data class SerializationConfig<A>(
-  val jsonSchema: JsonObject,
-  val descriptor: SerialDescriptor,
-  val deserializationStrategy: DeserializationStrategy<A>,
-)
-
 /**
  * An [AI] value represents an action relying on artificial intelligence that can be run to produce
  * an [A]. This value is _lazy_ and can be combined with other `AI` values using [AIScope.invoke],
@@ -57,18 +47,18 @@ suspend inline fun <A> AI<A>.getOrElse(crossinline orElse: suspend (AIError) ->

 @OptIn(ExperimentalTime::class)
 suspend fun <A> AIScope(block: suspend AIScope.() -> A, orElse: suspend (AIError) -> A): A =
-  recover({
+  try {
     resourceScope {
      val openAIConfig = OpenAIConfig()
      val openAiClient: OpenAIClient = KtorOpenAIClient(openAIConfig)
      val logger = KotlinLogging.logger("AutoAI")
      val embeddings = OpenAIEmbeddings(openAIConfig, openAiClient, logger)
      val vectorStore = LocalVectorStore(embeddings)
-     val scope = AIScope(openAiClient, vectorStore, embeddings, logger, this, this@recover)
+     val scope = AIScope(openAiClient, vectorStore, embeddings, logger, this)
      block(scope)
    }
-  }) {
-    orElse(it)
+  } catch (e: AIError) {
+    orElse(e)
   }

 /**
@@ -82,9 +72,6 @@ suspend fun <A> AIScope(block: suspend AIScope.() -> A, orElse: suspend (AIError
 suspend inline fun <reified A> AI<A>.toEither(): Either<AIError, A> =
   ai { invoke().right() }.getOrElse { it.left() }

-// TODO: Allow traced transformation of Raise errors
-class AIException(message: String) : RuntimeException(message)
-
 /**
  * Run the [AI] value to produce [A]. this method initialises all the dependencies required to run
  * the [AI] value and once it finishes it closes all the resources.
@@ -95,7 +82,7 @@ class AIException(message: String) : RuntimeException(message)
  * @see getOrElse for an operator that allow directly handling the [AIError] case instead of
  *   throwing.
  */
-suspend inline fun <reified A> AI<A>.getOrThrow(): A = getOrElse { throw AIException(it.reason) }
+suspend inline fun <reified A> AI<A>.getOrThrow(): A = getOrElse { throw it }

 /**
  * The [AIScope] is the context in which [AI] values are run. It encapsulates all the dependencies
@@ -109,9 +96,8 @@ class AIScope(
   val context: VectorStore,
   internal val embeddings: Embeddings,
   private val logger: KLogger,
-  resourceScope: ResourceScope,
-  raise: Raise<AIError>,
-) : ResourceScope by resourceScope, Raise<AIError> by raise {
+  resourceScope: ResourceScope
+) : ResourceScope by resourceScope {

   /**
    * Allows invoking [AI] values in the context of this [AIScope].
@@ -169,8 +155,7 @@
         CombinedVectorStore(newStore, this@AIScope.context),
         this@AIScope.embeddings,
         this@AIScope.logger,
-        this,
-        this@AIScope
+        this
       )
       .block()
   }
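
Net effect on the public runners: getOrElse keeps its shape, toEither still reifies the failure, and getOrThrow now rethrows the AIError itself instead of wrapping it in the removed AIException. A usage sketch against the signatures in this file:

import arrow.core.Either
import com.xebia.functional.xef.AIError
import com.xebia.functional.xef.auto.AI
import com.xebia.functional.xef.auto.getOrElse
import com.xebia.functional.xef.auto.getOrThrow
import com.xebia.functional.xef.auto.toEither

suspend fun demo(program: AI<String>) {
  val a: String = program.getOrElse { e -> "fallback: ${e.message}" } // handle the AIError inline
  val b: Either<AIError, String> = program.toEither()                 // reify failure into Either
  val c: String = program.getOrThrow()                                // rethrows the AIError as-is
}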
@@ -5,7 +5,6 @@ package com.xebia.functional.xef.auto

 import arrow.core.nonFatalOrThrow
 import arrow.core.raise.catch
-import arrow.core.raise.ensureNotNull
 import com.xebia.functional.xef.AIError
 import com.xebia.functional.xef.auto.serialization.buildJsonSchema
 import com.xebia.functional.xef.llm.openai.LLMModel
@@ -172,14 +171,14 @@ suspend fun <A> AIScope.tryDeserialize(
   agent: AI<List<String>>
 ): A {
   (0 until maxDeserializationAttempts).forEach { currentAttempts ->
-    val result = ensureNotNull(agent().firstOrNull()) { AIError.NoResponse }
+    val result = agent().firstOrNull() ?: throw AIError.NoResponse()
     catch({
       return@tryDeserialize serializer(result)
     }) { e: Throwable ->
       if (currentAttempts == maxDeserializationAttempts)
-        raise(AIError.JsonParsing(result, maxDeserializationAttempts, e.nonFatalOrThrow()))
-      // else continue with the next attempt
+        throw AIError.JsonParsing(result, maxDeserializationAttempts, e.nonFatalOrThrow())
+      // TODO else log attempt ?
     }
   }
-  raise(AIError.NoResponse)
+  throw AIError.NoResponse()
 }
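
The retry loop keeps its shape; only the exit points change. Note that the guard compares currentAttempts against maxDeserializationAttempts itself, which a zero-based until loop never reaches, so in practice the final failure surfaces as NoResponse. A standalone sketch of the same control flow, stripped of xef types and using an explicit last-attempt check:

// Standalone sketch (no xef types): try up to maxAttempts parses, surface the
// parse failure with its cause on the last attempt, fail fast on no response.
fun <A> retryParse(maxAttempts: Int, next: () -> String?, parse: (String) -> A): A {
  repeat(maxAttempts) { attempt ->
    val raw = next() ?: error("no response")
    try {
      return parse(raw)
    } catch (e: Exception) {
      if (attempt == maxAttempts - 1)
        throw IllegalStateException("failed to parse after $maxAttempts attempts: $raw", e)
      // otherwise fall through and try the next attempt
    }
  }
  error("no response")
}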
@@ -26,7 +26,7 @@ suspend inline fun <reified A> AIScope.image(
   bringFromContext: Int = 10
 ): A {
   val imageResponse = images(prompt, user, 1, size, bringFromContext)
-  val url = imageResponse.data.firstOrNull() ?: raise(AIError.NoResponse)
+  val url = imageResponse.data.firstOrNull() ?: throw AIError.NoResponse()
   return prompt<A>(
     """|Instructions: Format this [URL] and [PROMPT] information in the desired JSON response format
        |specified at the end of the message.
@@ -113,7 +113,7 @@ suspend fun <A> AIScope.image(
   minResponseTokens: Int = 500
 ): A {
   val imageResponse = images(prompt, user, 1, size, bringFromContext)
-  val url = imageResponse.data.firstOrNull() ?: raise(AIError.NoResponse)
+  val url = imageResponse.data.firstOrNull() ?: throw AIError.NoResponse()
   return prompt(
     Prompt(
       """|Instructions: Format this [URL] and [PROMPT] information in the desired JSON response format
@@ -3,8 +3,6 @@

 package com.xebia.functional.xef.auto

-import arrow.core.raise.Raise
-import arrow.core.raise.ensure
 import com.xebia.functional.tokenizer.Encoding
 import com.xebia.functional.tokenizer.ModelType
 import com.xebia.functional.tokenizer.truncateText
@@ -80,7 +78,7 @@ suspend fun AIScope.promptMessage(
   }
 }

-private fun Raise<AIError>.createPromptWithContextAwareOfTokens(
+private fun createPromptWithContextAwareOfTokens(
   ctxInfo: List<String>,
   modelType: ModelType,
   prompt: String,
@@ -93,8 +91,8 @@ private fun createPromptWithContextAwareOfTokens(
   return if (ctxInfo.isNotEmpty() && remainingTokens > minResponseTokens) {
     val ctx: String = ctxInfo.joinToString("\n")

-    ensure(promptTokens < maxContextLength) {
-      raise(AIError.PromptExceedsMaxTokenLength(prompt, promptTokens, maxContextLength))
+    if (promptTokens >= maxContextLength) {
+      throw AIError.PromptExceedsMaxTokenLength(prompt, promptTokens, maxContextLength)
     }
     // truncate the context if it's too long based on the max tokens calculated considering the
     // existing prompt tokens
@@ -184,7 +182,7 @@ private suspend fun AIScope.promptWithContext(
   )
 }

 private fun AIScope.checkTotalLeftTokens(
   modelType: ModelType,
   role: String,
   promptWithContext: String
@@ -196,7 +194,7 @@ private fun AIScope.checkTotalLeftTokens(
   val takenTokens: Int = roleTokens + promptTokens + padding
   val totalLeftTokens: Int = maxContextLength - takenTokens
   if (totalLeftTokens < 0) {
-    raise(AIError.PromptExceedsMaxTokenLength(promptWithContext, takenTokens, maxContextLength))
+    throw AIError.PromptExceedsMaxTokenLength(promptWithContext, takenTokens, maxContextLength)
   }
   logger.debug {
     "Tokens -- used: $takenTokens, model max: $maxContextLength, left: $totalLeftTokens"
@@ -209,7 +207,7 @@ private fun AIScope.checkTotalLeftChatTokens(messages: List<Message>, model: LLM
   val messagesTokens: Int = tokensFromMessages(messages, model)
   val totalLeftTokens: Int = maxContextLength - messagesTokens
   if (totalLeftTokens < 0) {
-    raise(AIError.MessagesExceedMaxTokenLength(messages, messagesTokens, maxContextLength))
+    throw AIError.MessagesExceedMaxTokenLength(messages, messagesTokens, maxContextLength)
   }
   logger.debug {
     "Tokens -- used: $messagesTokens, model max: $maxContextLength, left: $totalLeftTokens"
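
The token-accounting logic is untouched apart from the throws. A worked sketch of the budget arithmetic with illustrative numbers (a hypothetical 4096-token context window):

// Worked sketch of the check above; all numbers are illustrative.
fun checkBudget(maxContextLength: Int, roleTokens: Int, promptTokens: Int, padding: Int): Int {
  val takenTokens = roleTokens + promptTokens + padding
  val totalLeftTokens = maxContextLength - takenTokens
  // mirrors the PromptExceedsMaxTokenLength throw when the budget is exceeded
  require(totalLeftTokens >= 0) { "Prompt exceeds max token length: $takenTokens > $maxContextLength" }
  return totalLeftTokens
}

fun main() {
  // 4 role tokens + 3800 prompt tokens + 20 padding = 3824 used, 272 left
  println(checkBudget(maxContextLength = 4096, roleTokens = 4, promptTokens = 3800, padding = 20))
}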
71 changes: 18 additions & 53 deletions core/src/commonMain/kotlin/com/xebia/functional/xef/env/config.kt
@@ -1,66 +1,31 @@
 package com.xebia.functional.xef.env

-import arrow.core.NonEmptyList
-import arrow.core.raise.Raise
-import arrow.core.raise.catch
-import arrow.core.raise.recover
-import arrow.core.raise.zipOrAccumulate
 import arrow.resilience.Schedule
-import com.xebia.functional.xef.AIError
-import io.ktor.http.Url as KUrl
+import kotlin.jvm.JvmOverloads
 import kotlin.time.Duration
 import kotlin.time.Duration.Companion.seconds

-data class Env(val openAI: OpenAIConfig, val huggingFace: HuggingFaceConfig)
-
-data class OpenAIConfig(
-  val token: String,
-  val baseUrl: KUrl,
-  val chunkSize: Int,
-  val retryConfig: RetryConfig,
-  val requestTimeout: Duration
+data class OpenAIConfig
+@JvmOverloads
+constructor(
+  val token: String =
+    requireNotNull(getenv("OPENAI_TOKEN")) { "OpenAI Token missing from environment." },
+  val baseUrl: String = "https://api.openai.com/v1/",
+  val chunkSize: Int = 300,
+  val retryConfig: RetryConfig = RetryConfig(),
+  val requestTimeoutMillis: Long = 30_000
 )

-data class RetryConfig(val backoff: Duration, val maxRetries: Long) {
+data class RetryConfig(val backoff: Duration = 5.seconds, val maxRetries: Long = 5) {
   fun schedule(): Schedule<Throwable, Long> =
     Schedule.recurs<Throwable>(maxRetries)
       .zipLeft(Schedule.exponential<Throwable>(backoff).jittered(0.75, 1.25))
 }

-data class HuggingFaceConfig(val token: String, val baseUrl: KUrl)
-
-fun Raise<NonEmptyList<AIError.Env>>.Env(): Env =
-  zipOrAccumulate({ OpenAIConfig() }, { withNel { HuggingFaceConfig() } }) { openAI, huggingFace ->
-    Env(openAI, huggingFace)
-  }
-
-fun Raise<AIError.Env.OpenAI>.OpenAIConfig(token: String? = null) =
-  recover({
-    zipOrAccumulate(
-      { token ?: env("OPENAI_TOKEN") },
-      { env("OPENAI_BASE_URI", default = Url("https://api.openai.com/v1/")) { Url(it) } },
-      { env("OPENAI_CHUNK_SIZE", default = 300) { it.toIntOrNull() } },
-      { env("OPENAI_BACKOFF", default = 5.seconds) { it.toIntOrNull()?.seconds } },
-      { env("OPENAI_MAX_RETRIES", default = 5) { it.toLongOrNull() } },
-      { env("OPENAI_REQUEST_TIMEOUT", default = 30.seconds) { it.toIntOrNull()?.seconds } },
-    ) { token2, baseUrl, chunkSize, backoff, maxRetries, requestTimeout ->
-      OpenAIConfig(token2, baseUrl, chunkSize, RetryConfig(backoff, maxRetries), requestTimeout)
-    }
-  }) { e: NonEmptyList<String> ->
-    raise(AIError.Env.OpenAI(e))
-  }
-
-fun Raise<AIError.Env.HuggingFace>.HuggingFaceConfig(token: String? = null) =
-  recover({
-    zipOrAccumulate(
-      { token ?: env("HF_TOKEN") },
-      { env("HF_BASE_URI", default = Url("https://api-inference.huggingface.co/")) { Url(it) } }
-    ) { token2, baseUrl ->
-      HuggingFaceConfig(token2, baseUrl)
-    }
-  }) { e: NonEmptyList<String> ->
-    raise(AIError.Env.HuggingFace(e))
-  }
-
-fun Raise<String>.Url(urlString: String): KUrl =
-  catch({ KUrl(urlString) }) { raise(it.message ?: "Invalid url: $it") }
+data class HuggingFaceConfig
+@JvmOverloads
+constructor(
+  val token: String =
+    requireNotNull(getenv("OPENAI_TOKEN")) { "OpenAI Token missing from environment." },
+  val baseUrl: String = "https://api-inference.huggingface.co/"
+)
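
With the defaults moved onto the constructors (and @JvmOverloads for Java callers), configuration no longer needs the Raise-based builders. A construction sketch, assuming OPENAI_TOKEN is set when the default token is used:

import com.xebia.functional.xef.env.OpenAIConfig
import com.xebia.functional.xef.env.RetryConfig
import kotlin.time.Duration.Companion.seconds

// Zero-argument construction: token read from the environment, rest defaulted.
val default = OpenAIConfig()

// Selective overrides via named arguments; everything else keeps its default.
val tuned = OpenAIConfig(
  token = "sk-...", // placeholder, not a real token
  retryConfig = RetryConfig(backoff = 2.seconds, maxRetries = 3),
  requestTimeoutMillis = 60_000
)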
41 changes: 0 additions & 41 deletions core/src/commonMain/kotlin/com/xebia/functional/xef/env/env.kt

This file was deleted.

12 changes: 12 additions & 0 deletions core/src/commonMain/kotlin/com/xebia/functional/xef/env/getenv.kt
@@ -0,0 +1,12 @@
+package com.xebia.functional.xef.env
+
+/**
+ * A function that reads the configuration from the environment. This only works on JVM, Native and
+ * NodeJS.
+ *
+ * In the browser, we default to `null`, so either rely on the default values or construct the
+ * values explicitly.
+ *
+ * We might be able to support browser through webpack.
+ */
+expect fun getenv(name: String): String?
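
The platform actuals are not part of this diff; on the JVM one would presumably be a thin wrapper along these lines:

package com.xebia.functional.xef.env

// Hypothetical JVM actual for the expect declaration above (not in this commit):
actual fun getenv(name: String): String? = System.getenv(name)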