Skip to content

Commit

Permalink
Port PromptTemplate (#6)
Browse files Browse the repository at this point in the history
  • Loading branch information
nomisRev authored Apr 21, 2023
1 parent 034b6f3 commit 5d218d7
Show file tree
Hide file tree
Showing 8 changed files with 363 additions and 33 deletions.
72 changes: 39 additions & 33 deletions build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@ repositories {
}

plugins {
base
alias(libs.plugins.kotlin.multiplatform)
alias(libs.plugins.spotless)
alias(libs.plugins.kotlinx.serialization)
base
alias(libs.plugins.kotlin.multiplatform)
alias(libs.plugins.spotless)
alias(libs.plugins.kotlinx.serialization)
}

java {
Expand All @@ -30,42 +30,48 @@ kotlin {
}
}
js(IR) {
browser {
commonWebpackConfig {
cssSupport {
enabled.set(true)
}
}
}
browser()
nodejs()
}
val hostOs = System.getProperty("os.name")
val isMingwX64 = hostOs.startsWith("Windows")
when {
hostOs == "Mac OS X" -> macosX64("native")
hostOs == "Linux" -> linuxX64("native")
isMingwX64 -> mingwX64("native")
else -> throw GradleException("Host OS is not supported in Kotlin/Native.")
hostOs == "Mac OS X" -> macosX64("native")
hostOs == "Linux" -> linuxX64("native")
isMingwX64 -> mingwX64("native")
else -> throw GradleException("Host OS is not supported in Kotlin/Native.")
}


sourceSets {
commonMain {
dependencies {
implementation(libs.arrow.fx)
implementation(libs.kotlinx.serialization.json)
implementation(libs.bundles.ktor.client)
implementation(libs.okio)
}
}


sourceSets {
commonMain {
dependencies {
implementation(libs.arrow.fx)
implementation(libs.kotlinx.serialization.json)
implementation(libs.bundles.ktor.client)
}
}
commonTest {
dependencies {
implementation(kotlin("test"))
}
}
commonTest {
dependencies {
implementation(libs.okio.fakefilesystem)
implementation(libs.kotest.property)
implementation(libs.kotest.framework)
implementation(libs.kotest.assertions)
implementation(libs.kotest.assertions.arrow)
}
}
val jvmTest by getting {
dependencies {
implementation(libs.kotest.junit5)
}
}
}
}

spotless {
kotlin {
ktfmt().googleStyle()
}
}
kotlin {
ktfmt().googleStyle()
}
}
10 changes: 10 additions & 0 deletions gradle/libs.versions.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ openai = "0.12.0"
kotlinx-json = "1.5.0"
ktor = "2.2.2"
spotless = "6.18.0"
okio = "3.3.0"
kotest = "5.5.4"
kotest-arrow = "1.3.0"

[libraries]
arrow-fx = { module = "io.arrow-kt:arrow-fx-coroutines", version.ref = "arrow" }
Expand All @@ -13,6 +16,13 @@ kotlinx-serialization-json = { module = "org.jetbrains.kotlinx:kotlinx-serializa
ktor-client = { module = "io.ktor:ktor-client-core", version.ref = "ktor" }
ktor-client-content-negotiation = { module = "io.ktor:ktor-client-content-negotiation", version.ref = "ktor" }
ktor-client-serialization = { module = "io.ktor:ktor-serialization-kotlinx-json", version.ref = "ktor" }
okio = { module = "com.squareup.okio:okio", version.ref = "okio" }
okio-fakefilesystem = { module = "com.squareup.okio:okio-fakefilesystem", version.ref = "okio" }
kotest-assertions = { module = "io.kotest:kotest-assertions-core", version.ref = "kotest" }
kotest-framework = { module = "io.kotest:kotest-framework-engine", version.ref = "kotest" }
kotest-property = { module = "io.kotest:kotest-property", version.ref = "kotest" }
kotest-junit5 = { module = "io.kotest:kotest-runner-junit5", version.ref = "kotest" }
kotest-assertions-arrow = { module = "io.kotest.extensions:kotest-assertions-arrow", version.ref = "kotest-arrow" }

[bundles]
ktor-client = [
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package com.xebia.functional.prompt

import arrow.core.raise.Raise
import okio.FileSystem
import okio.Path
import okio.buffer
import okio.use

fun Raise<InvalidTemplate>.PromptTemplate(
examples: List<String>,
suffix: String,
variables: List<String>,
prefix: String
): PromptTemplate {
val template = """|$prefix
|
|${examples.joinToString(separator = "\n")}
|
|$suffix""".trimMargin()
return PromptTemplate(Config(template, variables))
}

fun Raise<InvalidTemplate>.PromptTemplate(template: String, variables: List<String>): PromptTemplate =
PromptTemplate(Config(template, variables))

/**
* Creates a PromptTemplate based on a Path
* JVM & Native have overloads for FileSystem.SYSTEM,
* on NodeJs you need to manually pass FileSystem.SYSTEM.
*
* This function can currently not be used on the browser.
*
* https://github.com/square/okio/issues/1070
* https://youtrack.jetbrains.com/issue/KT-47038
*/
suspend fun Raise<InvalidTemplate>.PromptTemplate(
path: Path,
variables: List<String>,
fileSystem: FileSystem
): PromptTemplate =
fileSystem.source(path).use { source ->
source.buffer().use { buffer ->
val template = buffer.readUtf8()
val config = Config(template, variables)
PromptTemplate(config)
}
}

interface PromptTemplate {
val inputKeys: List<String>
suspend fun format(variables: Map<String, String>): String

companion object {
operator fun invoke(config: Config): PromptTemplate = object : PromptTemplate {
override val inputKeys: List<String> = config.inputVariables

override suspend fun format(variables: Map<String, String>): String {
val mergedArgs = mergePartialAndUserVariables(variables, config.inputVariables)
return when (config.templateFormat) {
TemplateFormat.FString -> {
val sortedArgs = mergedArgs.toList().sortedBy { it.first }
sortedArgs.fold(config.template) { acc, (k, v) -> acc.replace("{$k}", v) }
}
}
}

private fun mergePartialAndUserVariables(
variables: Map<String, String>,
inputVariables: List<String>
): Map<String, String> =
inputVariables.fold(variables) { acc, k ->
if (!acc.containsKey(k)) acc + (k to "{$k}") else acc
}
}
}
}
57 changes: 57 additions & 0 deletions src/commonMain/kotlin/com/xebia/functional/prompt/models.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package com.xebia.functional.prompt

import arrow.core.Either
import arrow.core.NonEmptyList
import arrow.core.raise.Raise
import arrow.core.raise.either
import arrow.core.raise.ensure
import arrow.core.raise.zipOrAccumulate

enum class TemplateFormat {
FString
}

data class InvalidTemplate(val reason: String)

fun Raise<InvalidTemplate>.Config(template: String, inputVariables: List<String>): Config =
Config.either(template, inputVariables).bind()

class Config private constructor(
val inputVariables: List<String>,
val template: String,
val templateFormat: TemplateFormat = TemplateFormat.FString
) {
companion object {
// We cannot define `operator fun invoke` with `Raise` without context receivers,
// so we define an intermediate `Either` based function.
// This is because adding `Raise<InvalidTemplate>` results in 2 receivers.
fun either(template: String, variables: List<String>): Either<InvalidTemplate, Config> =
either<NonEmptyList<InvalidTemplate>, Config> {
val placeholders = placeholderValues(template)

zipOrAccumulate(
{ validate(template, variables.toSet() - placeholders.toSet(), "unused") },
{ validate(template, placeholders.toSet() - variables.toSet(), "missing") },
{ validateDuplicated(template, placeholders) }
) { _, _, _ -> Config(variables, template) }
}.mapLeft { InvalidTemplate(it.joinToString(transform = InvalidTemplate::reason)) }
}
}

private fun Raise<InvalidTemplate>.validate(template: String, diffSet: Set<String>, msg: String): Unit =
ensure(diffSet.isEmpty()) {
InvalidTemplate("Template '$template' has $msg arguments: ${diffSet.joinToString(", ") { "{$it}" }}")
}

private fun Raise<InvalidTemplate>.validateDuplicated(template: String, placeholders: List<String>) {
val args = placeholders.groupBy { it }.filter { it.value.size > 1 }.keys
ensure(args.isEmpty()) {
InvalidTemplate("Template '$template' has duplicate arguments: ${args.joinToString(", ") { "{$it}" }}")
}
}

private fun placeholderValues(template: String): List<String> {
@Suppress("RegExpRedundantEscape")
val regex = Regex("""\{([^\{\}]+)\}""")
return regex.findAll(template).toList().mapNotNull { it.groupValues.getOrNull(1) }
}
59 changes: 59 additions & 0 deletions src/commonTest/kotlin/com/xebia/functional/prompt/ConfigSpec.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package com.xebia.functional.prompt

import arrow.core.raise.either
import io.kotest.assertions.arrow.core.shouldBeLeft
import io.kotest.assertions.arrow.core.shouldBeRight
import io.kotest.core.spec.style.StringSpec
import io.kotest.matchers.shouldBe

class ConfigSpec : StringSpec({

"should return a valid Config if the template and input variables are valid" {
val template = "Hello {name}, you are {age} years old."
val variables = listOf("name", "age")

val config = either { Config(template, variables) }.shouldBeRight()

config.inputVariables shouldBe variables
config.template shouldBe template
config.templateFormat shouldBe TemplateFormat.FString
}

"should fail with a InvalidTemplateError if the template has missing arguments" {
val template = "Hello {name}, you are {age} years old."
val variables = listOf("name")

either {
Config(template, variables)
} shouldBeLeft InvalidTemplate("Template 'Hello {name}, you are {age} years old.' has missing arguments: {age}")
}

"should fail with a InvalidTemplateError if the template has unused arguments" {
val template = "Hello {name}, you are {age} years old."
val variables = listOf("name", "age", "unused")

either {
Config(template, variables)
} shouldBeLeft InvalidTemplate("Template 'Hello {name}, you are {age} years old.' has unused arguments: {unused}")
}

"should fail with a InvalidTemplateError if there are duplicate input variables" {
val template = "Hello {name}, you are {name} years old."
val variables = listOf("name")

either {
Config(template, variables)
} shouldBeLeft InvalidTemplate("Template 'Hello {name}, you are {name} years old.' has duplicate arguments: {name}")
}

"should fail with a combination of InvalidTemplateErrors if there are multiple things wrong" {
val template = "Hello {name}, you are {name} years old."
val variables = listOf("name", "age")
val unused = "Template 'Hello {name}, you are {name} years old.' has unused arguments: {age}"
val duplicated = "Template 'Hello {name}, you are {name} years old.' has duplicate arguments: {name}"

either {
Config(template, variables)
} shouldBeLeft InvalidTemplate("$unused, $duplicated")
}
})
Loading

0 comments on commit 5d218d7

Please sign in to comment.