Skip to content

Commit

Permalink
SequenceChain Implementation (#29)
Browse files Browse the repository at this point in the history
* Adding SequenceChain as a Chain

* SequenceChain implementation

* Update kotlin/src/commonMain/kotlin/com/xebia/functional/chains/SequenceChain.kt

Co-authored-by: Simon Vergauwen <nomisRev@users.noreply.github.com>

* Update kotlin/src/commonMain/kotlin/com/xebia/functional/chains/SequenceChain.kt

Co-authored-by: Simon Vergauwen <nomisRev@users.noreply.github.com>

* Some requested changes

---------

Co-authored-by: yago <Yawolf@users.noreply.github.com>
Co-authored-by: Simon Vergauwen <nomisRev@users.noreply.github.com>
  • Loading branch information
3 people authored May 9, 2023
1 parent a610e08 commit 7b29ab5
Show file tree
Hide file tree
Showing 6 changed files with 217 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ interface Chain {

data class InvalidInputs(override val reason: String): Error(reason)

data class InvalidOutputs(override val reason: String): Error(reason)

data class Config(
val inputKeys: Set<String>,
val outputKeys: Set<String>,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,91 @@
package com.xebia.functional.chains

interface SequenceChain : Chain {
data class InvalidOutputs(override val reason: String): Chain.Error(reason)
data class InvalidKeys(override val reason: String): Chain.Error(reason)
import arrow.core.Either
import arrow.core.flatten
import arrow.core.raise.either
import arrow.core.raise.ensure
import arrow.core.raise.Raise
import arrow.core.raise.recover
import arrow.core.raise.zipOrAccumulate
import arrow.core.raise.mapOrAccumulate

fun Raise<Chain.Error>.SequenceChain(
chains: List<Chain>,
inputVariables: List<String>,
outputVariables: List<String>,
chainOutput: Chain.ChainOutput = Chain.ChainOutput.OnlyOutput
): SequenceChain =
SequenceChain.either(chains, inputVariables, outputVariables, chainOutput).bind()

open class SequenceChain(
private val chains: List<Chain>,
private val inputVariables: List<String>,
private val outputVariables: List<String>,
chainOutput: Chain.ChainOutput = Chain.ChainOutput.OnlyOutput
) : Chain {
data class InvalidOutputs(override val reason: String) : Chain.Error(reason)
data class InvalidKeys(override val reason: String) : Chain.Error(reason)

override val config = Chain.Config(inputVariables.toSet(), outputVariables.toSet(), chainOutput)

private val outputs = when (chainOutput) {
Chain.ChainOutput.OnlyOutput -> outputVariables
Chain.ChainOutput.InputAndOutput -> outputVariables.plus(inputVariables)
}

override suspend fun call(inputs: Map<String, String>): Either<Chain.Error, Map<String, String>> =
either {
val chainRes = chains.fold(inputs) { inputs0, chain ->
chain.run(inputs0).map { inputs0 + it }.bind()
}
chainRes.filter { it.key in outputs }
}

companion object {
fun either(
chains: List<Chain>,
inputVariables: List<String>,
outputVariables: List<String>,
chainOutput: Chain.ChainOutput
): Either<InvalidKeys, SequenceChain> =
either {
val allOutputs = chains.map { it.config.outputKeys }.toSet().flatten()
val mappedChains: List<Chain> = recover({
mapOrAccumulate(chains) { chain ->
zipOrAccumulate(
{ validateSequenceOutputs(outputVariables, allOutputs) },
{ validateInputsOverlapping(inputVariables, allOutputs) },
) { _, _ -> chain }
}
}) { raise(InvalidKeys(reason = it.flatten().joinToString(transform = Chain.Error::reason))) }
SequenceChain(mappedChains, inputVariables, outputVariables, chainOutput)
}
}
}

private fun Raise<Chain.InvalidOutputs>.validateSequenceOutputs(
sequenceOutputs: List<String>,
chainOutputs: List<String>
): Unit =
ensure(sequenceOutputs.isNotEmpty() && sequenceOutputs.all { it in chainOutputs }) {
Chain.InvalidOutputs("The provided outputs: " +
sequenceOutputs.joinToString(", ") { "{$it}" } +
" do not exist in chains' outputs: " +
chainOutputs.joinToString { "{$it}" }
)
}

private fun Raise<Chain.InvalidInputs>.validateInputsOverlapping(
sequenceInputs: List<String>,
chainOutputs: List<String>
): Unit =
ensure(sequenceInputs.isNotEmpty() && sequenceInputs.all { it !in chainOutputs }) {
Chain.InvalidInputs("The provided inputs: " +
sequenceInputs.joinToString { "{$it}" } +
" overlap with chain's outputs: " +
chainOutputs.joinToString { "{$it}" }

)
}


Original file line number Diff line number Diff line change
@@ -1,11 +1,7 @@
package com.xebia.functional.chains

import arrow.core.Either
import arrow.core.raise.Raise
import arrow.core.raise.either
import arrow.core.raise.ensure
import arrow.core.raise.recover
import arrow.core.raise.zipOrAccumulate
import arrow.core.raise.*

fun Raise<Chain.Error>.SimpleSequenceChain(
chains: List<Chain>,
Expand All @@ -20,7 +16,7 @@ class SimpleSequenceChain private constructor(
private val inputKey: String,
private val outputKey: String,
chainOutput: Chain.ChainOutput
) : SequenceChain {
) : SequenceChain(chains, listOf(inputKey), listOf(outputKey), chainOutput) {

override val config = Chain.Config(setOf(inputKey), setOf(outputKey), chainOutput)

Expand All @@ -40,14 +36,14 @@ class SimpleSequenceChain private constructor(
inputKey: String,
outputKey: String,
chainOutput: Chain.ChainOutput
): Either<SequenceChain.InvalidKeys, SimpleSequenceChain> =
): Either<InvalidKeys, SimpleSequenceChain> =
either {
val mappedChains: List<Chain> = chains.map { chain ->
recover({
zipOrAccumulate(
{ validateInputKeys(chain.config.inputKeys) },
{ validateOutputKeys(chain.config.outputKeys) }) { _, _ -> chain }
}) { raise(SequenceChain.InvalidKeys(it.joinToString(transform = Chain.Error::reason))) }
}) { raise(InvalidKeys(reason = it.joinToString(transform = Chain.Error::reason))) }
}
SimpleSequenceChain(mappedChains, inputKey, outputKey, chainOutput)
}
Expand All @@ -65,3 +61,4 @@ private fun Raise<Chain.InvalidInputs>.validateInputKeys(inputKeys: Set<String>)
Chain.InvalidInputs("The expected inputs are more than one: " +
inputKeys.joinToString(", ") { "{$it}" })
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
package com.xebia.functional.chains

import arrow.core.raise.either
import io.kotest.assertions.arrow.core.shouldBeLeft
import io.kotest.assertions.arrow.core.shouldBeRight
import io.kotest.core.spec.style.StringSpec

class SequenceChainSpec : StringSpec({
"SequenceChain should return a prediction with one Chain" {
val chain1 = FakeChain(inputVariables = setOf("foo"), outputVariables = setOf("bar"))
val chains = listOf(chain1)

either {
val sc = SequenceChain(
chains = chains,
inputVariables = listOf("foo"),
outputVariables = listOf("bar"),
chainOutput = Chain.ChainOutput.InputAndOutput
)
sc.run(mapOf("foo" to "123")).bind()
} shouldBeRight mapOf("foo" to "123", "bar" to "123dr")
}

"SequenceChain should return a prediction on a single input chain" {
val chain1 = FakeChain(inputVariables = setOf("foo"), outputVariables = setOf("bar"))
val chain2 = FakeChain(inputVariables = setOf("bar"), outputVariables = setOf("baz"))
val chains = listOf(chain1, chain2)

either {
val sc = SequenceChain(
chains = chains,
inputVariables = listOf("foo"),
outputVariables = listOf("baz"),
chainOutput = Chain.ChainOutput.InputAndOutput
)
sc.run(mapOf("foo" to "123")).bind()
} shouldBeRight mapOf("foo" to "123", "baz" to "123drdr")
}

"SequenceChain should return a prediction on a multiple input chain" {
val chain1 = FakeChain(inputVariables = setOf("foo", "test"), outputVariables = setOf("bar"))
val chain2 = FakeChain(inputVariables = setOf("bar", "foo"), outputVariables = setOf("baz"))
val chains = listOf(chain1, chain2)

either {
val sc = SequenceChain(
chains = chains,
inputVariables = listOf("foo", "test"),
outputVariables = listOf("baz"),
chainOutput = Chain.ChainOutput.InputAndOutput
)
sc.run(mapOf("foo" to "123", "test" to "456")).bind()
} shouldBeRight mapOf("foo" to "123", "test" to "456", "baz" to "123456dr123dr")
}

"SequenceChain should return a prediction on a multiple output chain" {
val chain1 = FakeChain(inputVariables = setOf("foo"), outputVariables = setOf("bar", "test"))
val chain2 = FakeChain(inputVariables = setOf("bar", "foo"), outputVariables = setOf("baz"))
val chains = listOf(chain1, chain2)

either {
val sc = SequenceChain(
chains = chains,
inputVariables = listOf("foo"),
outputVariables = listOf("baz"),
chainOutput = Chain.ChainOutput.InputAndOutput
)
sc.run(mapOf("foo" to "123")).bind()
} shouldBeRight mapOf("foo" to "123", "baz" to "123dr123dr")
}

"SequenceChain should fail when input variables are missing" {
val chain1 = FakeChain(inputVariables = setOf("foo"), outputVariables = setOf("bar"))
val chain2 = FakeChain(inputVariables = setOf("bar", "test"), outputVariables = setOf("baz"))
val chains = listOf(chain1, chain2)

either {
val sc = SequenceChain(
chains = chains,
inputVariables = listOf("foo"),
outputVariables = listOf("baz"),
chainOutput = Chain.ChainOutput.InputAndOutput
)
sc.run(mapOf("foo" to "123")).bind()
} shouldBeLeft Chain.InvalidInputs("The provided inputs: {foo}, {bar} do not match with chain's inputs: {bar}, {test}")
}

"SequenceChain should fail when output variables are missing" {
val chain1 = FakeChain(inputVariables = setOf("foo"), outputVariables = setOf("bar"))
val chain2 = FakeChain(inputVariables = setOf("bar"), outputVariables = setOf("baz"))
val chains = listOf(chain1, chain2)

either {
val sc = SequenceChain.either(
chains = chains,
inputVariables = listOf("foo"),
outputVariables = listOf("test"),
chainOutput = Chain.ChainOutput.InputAndOutput
).bind()
sc.run(mapOf("foo" to "123")).bind()
} shouldBeLeft SequenceChain.InvalidKeys("The provided outputs: {test} do not exist in chains' outputs: {bar}, {baz}")
}

"SequenceChain should fail when input variables are overlapping" {
val chain1 = FakeChain(inputVariables = setOf("foo"), outputVariables = setOf("bar", "test"))
val chain2 = FakeChain(inputVariables = setOf("bar"), outputVariables = setOf("baz"))
val chains = listOf(chain1, chain2)

either {
val sc = SequenceChain.either(
chains = chains,
inputVariables = listOf("foo", "test"),
outputVariables = listOf("baz"),
chainOutput = Chain.ChainOutput.InputAndOutput
).bind()
sc.run(mapOf("foo" to "123")).bind()
} shouldBeLeft SequenceChain.InvalidKeys("The provided inputs: {foo}, {test} overlap with chain's outputs: {bar}, {test}, {baz}")
}
})
2 changes: 1 addition & 1 deletion scala/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ plugins {
}

dependencies {
implementation(projects.langchain4kKotlin)
//implementation(projects.langchain4kKotlin)
implementation(libs.kotlinx.coroutines)
implementation(libs.ciris.core)
implementation(libs.ciris.refined)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ class SequentialChainSpec extends CatsEffectSuite:
interceptIO[MissingOutputVariablesError](output)
}

test("Test SequentialChainruns when valid outputs are specified.") {
test("Test SequentialChain runs when valid outputs are specified.") {
val chain1 = FakeChain(inputVariables = Set("foo"), outputVariables = Set("bar"))
val chain2 = FakeChain(inputVariables = Set("bar"), outputVariables = Set("baz"))
val chains = NonEmptySeq(chain1, Seq(chain2))
Expand All @@ -105,7 +105,7 @@ class SequentialChainSpec extends CatsEffectSuite:
assertIO(output, expectedOutput)
}

test("Test SequentialChainruns error is raised when input variables are overlapping.") {
test("Test SequentialChain runs error is raised when input variables are overlapping.") {
val chain1 = FakeChain(inputVariables = Set("foo"), outputVariables = Set("bar", "test"))
val chain2 = FakeChain(inputVariables = Set("bar"), outputVariables = Set("baz"))
val chains = NonEmptySeq(chain1, Seq(chain2))
Expand Down

0 comments on commit 7b29ab5

Please sign in to comment.