Skip to content

Commit

Permalink
More possibilities for contexts (#73)
Browse files Browse the repository at this point in the history
* More possibilities for contexts

* Make Spotless happy

* Fix examples
  • Loading branch information
serras authored May 17, 2023
1 parent 18ec7f2 commit 3ac51b2
Show file tree
Hide file tree
Showing 8 changed files with 81 additions and 18 deletions.
58 changes: 45 additions & 13 deletions core/src/commonMain/kotlin/com/xebia/functional/xef/auto/AI.kt
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ import com.xebia.functional.xef.llm.openai.KtorOpenAIClient
import com.xebia.functional.xef.llm.openai.LLMModel
import com.xebia.functional.xef.llm.openai.OpenAIClient
import com.xebia.functional.xef.prompt.PromptTemplate
import com.xebia.functional.xef.vectorstores.CombinedVectorStore
import com.xebia.functional.xef.vectorstores.LocalVectorStore
import com.xebia.functional.xef.vectorstores.LocalVectorStoreBuilder
import com.xebia.functional.xef.vectorstores.VectorStore
import io.github.oshai.KLogger
import io.github.oshai.KotlinLogging
Expand Down Expand Up @@ -156,14 +158,13 @@ class AIScope(
*/
@AiDsl @JvmName("invokeAI") suspend operator fun <A> AI<A>.invoke(): A = invoke(this@AIScope)

/** Runs the [agent] to enlarge the [context], and then executes the [scope]. */
@AiDsl
suspend fun <A> context(agent: ContextualAgent, scope: suspend AIScope.() -> A): A =
context(listOf(agent), scope)
suspend fun extendContext(vararg docs: String) {
context.addTexts(docs.toList())
}

/** Runs the [agents] to enlarge the [context], and then executes the [scope]. */
@AiDsl
suspend fun <A> context(agents: Collection<ContextualAgent>, scope: suspend AIScope.() -> A): A {
suspend fun extendContext(vararg agents: ContextualAgent) {
agents.forEach {
logger.debug { "[${it.name}] Running" }
val docs = with(it) { call() }
Expand All @@ -174,14 +175,49 @@ class AIScope(
logger.debug { "[${it.name}] Found no docs" }
}
}
return scope(this)
}

@AiDsl
suspend fun <A> contextScope(
store: suspend (Embeddings) -> Resource<VectorStore>,
block: AI<A>
): A {
val newStore = store(embeddings).bind()
return AIScope(
openAIClient,
CombinedVectorStore(newStore, context),
embeddings,
logger,
this,
this
)
.block()
}

@AiDsl
suspend fun <A> contextScope(block: AI<A>): A = contextScope(LocalVectorStoreBuilder, block)

/** Add new [docs] to the [context], and then executes the [scope]. */
@AiDsl
suspend fun <A> context(docs: List<String>, scope: suspend AIScope.() -> A): A {
context.addTexts(docs)
return scope(this)
suspend fun <A> contextScope(docs: List<String>, scope: suspend AIScope.() -> A): A =
contextScope {
extendContext(*docs.toTypedArray())
scope(this)
}

/** Runs the [agent] to enlarge the [context], and then executes the [scope]. */
@AiDsl
suspend fun <A> contextScope(agent: ContextualAgent, scope: suspend AIScope.() -> A): A =
contextScope(listOf(agent), scope)

/** Runs the [agents] to enlarge the [context], and then executes the [scope]. */
@AiDsl
suspend fun <A> contextScope(
agents: Collection<ContextualAgent>,
scope: suspend AIScope.() -> A
): A = contextScope {
extendContext(*agents.toTypedArray())
scope(this)
}

@AiDsl
Expand Down Expand Up @@ -306,8 +342,4 @@ class AIScope(
{ prompt(it, mapOf("url" to url.url, "prompt" to prompt), llmModel) }
)
}

@AiDsl
suspend fun <A> withContextStore(store: (Embeddings) -> Resource<VectorStore>, block: AI<A>): A =
AIScope(openAIClient, store(embeddings).bind(), embeddings, logger, this, this).block()
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package com.xebia.functional.xef.vectorstores

import com.xebia.functional.xef.embeddings.Embedding

class CombinedVectorStore(private val top: VectorStore, private val bottom: VectorStore) :
VectorStore by top {

fun pop(): VectorStore = bottom

override suspend fun similaritySearch(query: String, limit: Int): List<String> {
val topResults = top.similaritySearch(query, limit)
return when {
topResults.size >= limit -> topResults
else -> topResults + bottom.similaritySearch(query, limit - topResults.size)
}
}

override suspend fun similaritySearchByVector(embedding: Embedding, limit: Int): List<String> {
val topResults = top.similaritySearchByVector(embedding, limit)
return when {
topResults.size >= limit -> topResults
else -> topResults + bottom.similaritySearchByVector(embedding, limit - topResults.size)
}
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package com.xebia.functional.xef.vectorstores

import arrow.fx.coroutines.Resource
import arrow.fx.coroutines.continuations.resource
import arrow.fx.stm.TMap
import arrow.fx.stm.TVar
import arrow.fx.stm.atomically
Expand All @@ -9,6 +11,10 @@ import com.xebia.functional.xef.llm.openai.EmbeddingModel
import com.xebia.functional.xef.llm.openai.RequestConfig
import kotlin.math.sqrt

val LocalVectorStoreBuilder: (Embeddings) -> Resource<LocalVectorStore> = { e ->
resource { LocalVectorStore(e) }
}

class LocalVectorStore
private constructor(
private val embeddings: Embeddings,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ suspend fun main() {
ai {
val sdf = SimpleDateFormat("dd/M/yyyy")
val currentDate = sdf.format(Date())
context(search("$currentDate Covid News")) {
contextScope(search("$currentDate Covid News")) {
val news: BreakingNewsAboutCovid =
prompt("write a paragraph of about 300 words about: $currentDate Covid News")
println(news)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ data class NumberOfMedicalNeedlesInWorld(val numberOfNeedles: Long)

suspend fun main() {
ai {
context(search("Estimate amount of medical needles in the world")) {
contextScope(search("Estimate amount of medical needles in the world")) {
val needlesInWorld: NumberOfMedicalNeedlesInWorld =
prompt("Provide the number of medical needles in the world")
println("Needles in world: ${needlesInWorld.numberOfNeedles}")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ suspend fun main() {
ai {
val sdf = SimpleDateFormat("dd/M/yyyy")
val currentDate = sdf.format(Date())
context(search("$currentDate Stock market results, raising stocks, decreasing stocks")) {
contextScope(search("$currentDate Stock market results, raising stocks, decreasing stocks")) {
val news: MarketNews = prompt(
"""|
|Write a short summary of the stock market results given the provided context.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ data class MealPlan(val name: String, val recipes: List<Recipe>)

suspend fun main() {
ai {
context(search("gall bladder stones meals")) {
contextScope(search("gall bladder stones meals")) {
val mealPlan: MealPlan = prompt(
"Meal plan for the week for a person with gall bladder stones that includes 5 recipes."
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ suspend fun main() {
private suspend fun getQuestionAnswer(
question: String
): List<String> = ai {
context(search("Weather in Cádiz, Spain")) {
contextScope(search("Weather in Cádiz, Spain")) {
promptMessage(question)
}
}.getOrElse { throw IllegalStateException(it.reason) }

0 comments on commit 3ac51b2

Please sign in to comment.