diff --git a/kotlin/src/commonMain/kotlin/com/xebia/functional/chains/Chain.kt b/kotlin/src/commonMain/kotlin/com/xebia/functional/chains/Chain.kt index a0a9e730e..d519f3d09 100644 --- a/kotlin/src/commonMain/kotlin/com/xebia/functional/chains/Chain.kt +++ b/kotlin/src/commonMain/kotlin/com/xebia/functional/chains/Chain.kt @@ -14,6 +14,8 @@ interface Chain { data class InvalidInputs(override val reason: String): Error(reason) + data class InvalidOutputs(override val reason: String): Error(reason) + data class Config( val inputKeys: Set, val outputKeys: Set, diff --git a/kotlin/src/commonMain/kotlin/com/xebia/functional/chains/SequenceChain.kt b/kotlin/src/commonMain/kotlin/com/xebia/functional/chains/SequenceChain.kt index 4c078b49e..47b2a890f 100644 --- a/kotlin/src/commonMain/kotlin/com/xebia/functional/chains/SequenceChain.kt +++ b/kotlin/src/commonMain/kotlin/com/xebia/functional/chains/SequenceChain.kt @@ -1,6 +1,91 @@ package com.xebia.functional.chains -interface SequenceChain : Chain { - data class InvalidOutputs(override val reason: String): Chain.Error(reason) - data class InvalidKeys(override val reason: String): Chain.Error(reason) +import arrow.core.Either +import arrow.core.flatten +import arrow.core.raise.either +import arrow.core.raise.ensure +import arrow.core.raise.Raise +import arrow.core.raise.recover +import arrow.core.raise.zipOrAccumulate +import arrow.core.raise.mapOrAccumulate + +fun Raise.SequenceChain( + chains: List, + inputVariables: List, + outputVariables: List, + chainOutput: Chain.ChainOutput = Chain.ChainOutput.OnlyOutput +): SequenceChain = + SequenceChain.either(chains, inputVariables, outputVariables, chainOutput).bind() + +open class SequenceChain( + private val chains: List, + private val inputVariables: List, + private val outputVariables: List, + chainOutput: Chain.ChainOutput = Chain.ChainOutput.OnlyOutput +) : Chain { + data class InvalidOutputs(override val reason: String) : Chain.Error(reason) + data class InvalidKeys(override val reason: String) : Chain.Error(reason) + + override val config = Chain.Config(inputVariables.toSet(), outputVariables.toSet(), chainOutput) + + private val outputs = when (chainOutput) { + Chain.ChainOutput.OnlyOutput -> outputVariables + Chain.ChainOutput.InputAndOutput -> outputVariables.plus(inputVariables) + } + + override suspend fun call(inputs: Map): Either> = + either { + val chainRes = chains.fold(inputs) { inputs0, chain -> + chain.run(inputs0).map { inputs0 + it }.bind() + } + chainRes.filter { it.key in outputs } + } + + companion object { + fun either( + chains: List, + inputVariables: List, + outputVariables: List, + chainOutput: Chain.ChainOutput + ): Either = + either { + val allOutputs = chains.map { it.config.outputKeys }.toSet().flatten() + val mappedChains: List = recover({ + mapOrAccumulate(chains) { chain -> + zipOrAccumulate( + { validateSequenceOutputs(outputVariables, allOutputs) }, + { validateInputsOverlapping(inputVariables, allOutputs) }, + ) { _, _ -> chain } + } + }) { raise(InvalidKeys(reason = it.flatten().joinToString(transform = Chain.Error::reason))) } + SequenceChain(mappedChains, inputVariables, outputVariables, chainOutput) + } + } } + +private fun Raise.validateSequenceOutputs( + sequenceOutputs: List, + chainOutputs: List +): Unit = + ensure(sequenceOutputs.isNotEmpty() && sequenceOutputs.all { it in chainOutputs }) { + Chain.InvalidOutputs("The provided outputs: " + + sequenceOutputs.joinToString(", ") { "{$it}" } + + " do not exist in chains' outputs: " + + chainOutputs.joinToString { "{$it}" } + ) + } + +private fun Raise.validateInputsOverlapping( + sequenceInputs: List, + chainOutputs: List +): Unit = + ensure(sequenceInputs.isNotEmpty() && sequenceInputs.all { it !in chainOutputs }) { + Chain.InvalidInputs("The provided inputs: " + + sequenceInputs.joinToString { "{$it}" } + + " overlap with chain's outputs: " + + chainOutputs.joinToString { "{$it}" } + + ) + } + + diff --git a/kotlin/src/commonMain/kotlin/com/xebia/functional/chains/SimpleSequenceChain.kt b/kotlin/src/commonMain/kotlin/com/xebia/functional/chains/SimpleSequenceChain.kt index cbe6094fa..a64601a6f 100644 --- a/kotlin/src/commonMain/kotlin/com/xebia/functional/chains/SimpleSequenceChain.kt +++ b/kotlin/src/commonMain/kotlin/com/xebia/functional/chains/SimpleSequenceChain.kt @@ -1,11 +1,7 @@ package com.xebia.functional.chains import arrow.core.Either -import arrow.core.raise.Raise -import arrow.core.raise.either -import arrow.core.raise.ensure -import arrow.core.raise.recover -import arrow.core.raise.zipOrAccumulate +import arrow.core.raise.* fun Raise.SimpleSequenceChain( chains: List, @@ -20,7 +16,7 @@ class SimpleSequenceChain private constructor( private val inputKey: String, private val outputKey: String, chainOutput: Chain.ChainOutput -) : SequenceChain { +) : SequenceChain(chains, listOf(inputKey), listOf(outputKey), chainOutput) { override val config = Chain.Config(setOf(inputKey), setOf(outputKey), chainOutput) @@ -40,14 +36,14 @@ class SimpleSequenceChain private constructor( inputKey: String, outputKey: String, chainOutput: Chain.ChainOutput - ): Either = + ): Either = either { val mappedChains: List = chains.map { chain -> recover({ zipOrAccumulate( { validateInputKeys(chain.config.inputKeys) }, { validateOutputKeys(chain.config.outputKeys) }) { _, _ -> chain } - }) { raise(SequenceChain.InvalidKeys(it.joinToString(transform = Chain.Error::reason))) } + }) { raise(InvalidKeys(reason = it.joinToString(transform = Chain.Error::reason))) } } SimpleSequenceChain(mappedChains, inputKey, outputKey, chainOutput) } @@ -65,3 +61,4 @@ private fun Raise.validateInputKeys(inputKeys: Set) Chain.InvalidInputs("The expected inputs are more than one: " + inputKeys.joinToString(", ") { "{$it}" }) } + diff --git a/kotlin/src/commonTest/kotlin/com/xebia/functional/chains/SequenceChainSpec.kt b/kotlin/src/commonTest/kotlin/com/xebia/functional/chains/SequenceChainSpec.kt new file mode 100644 index 000000000..1b9337240 --- /dev/null +++ b/kotlin/src/commonTest/kotlin/com/xebia/functional/chains/SequenceChainSpec.kt @@ -0,0 +1,119 @@ +package com.xebia.functional.chains + +import arrow.core.raise.either +import io.kotest.assertions.arrow.core.shouldBeLeft +import io.kotest.assertions.arrow.core.shouldBeRight +import io.kotest.core.spec.style.StringSpec + +class SequenceChainSpec : StringSpec({ + "SequenceChain should return a prediction with one Chain" { + val chain1 = FakeChain(inputVariables = setOf("foo"), outputVariables = setOf("bar")) + val chains = listOf(chain1) + + either { + val sc = SequenceChain( + chains = chains, + inputVariables = listOf("foo"), + outputVariables = listOf("bar"), + chainOutput = Chain.ChainOutput.InputAndOutput + ) + sc.run(mapOf("foo" to "123")).bind() + } shouldBeRight mapOf("foo" to "123", "bar" to "123dr") + } + + "SequenceChain should return a prediction on a single input chain" { + val chain1 = FakeChain(inputVariables = setOf("foo"), outputVariables = setOf("bar")) + val chain2 = FakeChain(inputVariables = setOf("bar"), outputVariables = setOf("baz")) + val chains = listOf(chain1, chain2) + + either { + val sc = SequenceChain( + chains = chains, + inputVariables = listOf("foo"), + outputVariables = listOf("baz"), + chainOutput = Chain.ChainOutput.InputAndOutput + ) + sc.run(mapOf("foo" to "123")).bind() + } shouldBeRight mapOf("foo" to "123", "baz" to "123drdr") + } + + "SequenceChain should return a prediction on a multiple input chain" { + val chain1 = FakeChain(inputVariables = setOf("foo", "test"), outputVariables = setOf("bar")) + val chain2 = FakeChain(inputVariables = setOf("bar", "foo"), outputVariables = setOf("baz")) + val chains = listOf(chain1, chain2) + + either { + val sc = SequenceChain( + chains = chains, + inputVariables = listOf("foo", "test"), + outputVariables = listOf("baz"), + chainOutput = Chain.ChainOutput.InputAndOutput + ) + sc.run(mapOf("foo" to "123", "test" to "456")).bind() + } shouldBeRight mapOf("foo" to "123", "test" to "456", "baz" to "123456dr123dr") + } + + "SequenceChain should return a prediction on a multiple output chain" { + val chain1 = FakeChain(inputVariables = setOf("foo"), outputVariables = setOf("bar", "test")) + val chain2 = FakeChain(inputVariables = setOf("bar", "foo"), outputVariables = setOf("baz")) + val chains = listOf(chain1, chain2) + + either { + val sc = SequenceChain( + chains = chains, + inputVariables = listOf("foo"), + outputVariables = listOf("baz"), + chainOutput = Chain.ChainOutput.InputAndOutput + ) + sc.run(mapOf("foo" to "123")).bind() + } shouldBeRight mapOf("foo" to "123", "baz" to "123dr123dr") + } + + "SequenceChain should fail when input variables are missing" { + val chain1 = FakeChain(inputVariables = setOf("foo"), outputVariables = setOf("bar")) + val chain2 = FakeChain(inputVariables = setOf("bar", "test"), outputVariables = setOf("baz")) + val chains = listOf(chain1, chain2) + + either { + val sc = SequenceChain( + chains = chains, + inputVariables = listOf("foo"), + outputVariables = listOf("baz"), + chainOutput = Chain.ChainOutput.InputAndOutput + ) + sc.run(mapOf("foo" to "123")).bind() + } shouldBeLeft Chain.InvalidInputs("The provided inputs: {foo}, {bar} do not match with chain's inputs: {bar}, {test}") + } + + "SequenceChain should fail when output variables are missing" { + val chain1 = FakeChain(inputVariables = setOf("foo"), outputVariables = setOf("bar")) + val chain2 = FakeChain(inputVariables = setOf("bar"), outputVariables = setOf("baz")) + val chains = listOf(chain1, chain2) + + either { + val sc = SequenceChain.either( + chains = chains, + inputVariables = listOf("foo"), + outputVariables = listOf("test"), + chainOutput = Chain.ChainOutput.InputAndOutput + ).bind() + sc.run(mapOf("foo" to "123")).bind() + } shouldBeLeft SequenceChain.InvalidKeys("The provided outputs: {test} do not exist in chains' outputs: {bar}, {baz}") + } + + "SequenceChain should fail when input variables are overlapping" { + val chain1 = FakeChain(inputVariables = setOf("foo"), outputVariables = setOf("bar", "test")) + val chain2 = FakeChain(inputVariables = setOf("bar"), outputVariables = setOf("baz")) + val chains = listOf(chain1, chain2) + + either { + val sc = SequenceChain.either( + chains = chains, + inputVariables = listOf("foo", "test"), + outputVariables = listOf("baz"), + chainOutput = Chain.ChainOutput.InputAndOutput + ).bind() + sc.run(mapOf("foo" to "123")).bind() + } shouldBeLeft SequenceChain.InvalidKeys("The provided inputs: {foo}, {test} overlap with chain's outputs: {bar}, {test}, {baz}") + } +}) \ No newline at end of file diff --git a/scala/build.gradle.kts b/scala/build.gradle.kts index e92c28617..8744efeaf 100644 --- a/scala/build.gradle.kts +++ b/scala/build.gradle.kts @@ -7,7 +7,7 @@ plugins { } dependencies { - implementation(projects.langchain4kKotlin) + //implementation(projects.langchain4kKotlin) implementation(libs.kotlinx.coroutines) implementation(libs.ciris.core) implementation(libs.ciris.refined) diff --git a/scala/src/test/scala/com/xebia/functional/chains/SequentialChainSpec.scala b/scala/src/test/scala/com/xebia/functional/chains/SequentialChainSpec.scala index 0ea5c6946..977a4de9c 100644 --- a/scala/src/test/scala/com/xebia/functional/chains/SequentialChainSpec.scala +++ b/scala/src/test/scala/com/xebia/functional/chains/SequentialChainSpec.scala @@ -89,7 +89,7 @@ class SequentialChainSpec extends CatsEffectSuite: interceptIO[MissingOutputVariablesError](output) } - test("Test SequentialChainruns when valid outputs are specified.") { + test("Test SequentialChain runs when valid outputs are specified.") { val chain1 = FakeChain(inputVariables = Set("foo"), outputVariables = Set("bar")) val chain2 = FakeChain(inputVariables = Set("bar"), outputVariables = Set("baz")) val chains = NonEmptySeq(chain1, Seq(chain2)) @@ -105,7 +105,7 @@ class SequentialChainSpec extends CatsEffectSuite: assertIO(output, expectedOutput) } - test("Test SequentialChainruns error is raised when input variables are overlapping.") { + test("Test SequentialChain runs error is raised when input variables are overlapping.") { val chain1 = FakeChain(inputVariables = Set("foo"), outputVariables = Set("bar", "test")) val chain2 = FakeChain(inputVariables = Set("bar"), outputVariables = Set("baz")) val chains = NonEmptySeq(chain1, Seq(chain2))