Skip to content

Commit

Permalink
ScenarioRunner: enrich incomplete transactions (#11384)
Browse files Browse the repository at this point in the history
* ScenarioRunner: enrich incomplete transactions

fixes #11352

CHANGELOG_BEGIN
CHANGELOG_END
  • Loading branch information
remyhaemmerle-da authored Oct 26, 2021
1 parent d9c7031 commit c8006b8
Show file tree
Hide file tree
Showing 8 changed files with 83 additions and 67 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,13 @@ package engine
import com.daml.lf.data.Ref.{Identifier, Name, PackageId}
import com.daml.lf.language.{Ast, LookupError}
import com.daml.lf.transaction.Node.{GenNode, KeyWithMaintainers}
import com.daml.lf.transaction.{Node, NodeId, VersionedTransaction}
import com.daml.lf.transaction.{
IncompleteTransaction,
GenTransaction,
Node,
NodeId,
VersionedTransaction,
}
import com.daml.lf.value.Value
import com.daml.lf.value.Value.VersionedValue
import com.daml.lf.speedy.SValue
Expand Down Expand Up @@ -169,7 +175,7 @@ final class ValueEnricher(
} yield exe.copy(chosenValue = choiceArg, exerciseResult = result, key = key)
}

def enrichTransaction(tx: VersionedTransaction): Result[VersionedTransaction] = {
def enrichTransaction(tx: GenTransaction): Result[GenTransaction] =
for {
normalizedNodes <-
tx.nodes.foldLeft[Result[Map[NodeId, GenNode]]](ResultDone(Map.empty)) {
Expand All @@ -179,11 +185,21 @@ final class ValueEnricher(
normalizedNode <- enrichNode(node)
} yield nodes.updated(nid, normalizedNode)
}
} yield VersionedTransaction(
version = tx.version,
} yield GenTransaction(
nodes = normalizedNodes,
roots = tx.roots,
)
}

def enrichVersionedTransaction(versionedTx: VersionedTransaction): Result[VersionedTransaction] =
enrichTransaction(GenTransaction(versionedTx.nodes, versionedTx.roots)).map {
case GenTransaction(nodes, roots) =>
VersionedTransaction(versionedTx.version, nodes, roots)
}

def enrichIncompleteTransaction(
incompleteTx: IncompleteTransaction
): Result[IncompleteTransaction] =
enrichTransaction(incompleteTx.transaction).map(transaction =>
incompleteTx.copy(transaction = transaction)
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ class LargeTransactionTest extends AnyWordSpec with Matchers with BazelRunfiles
case ResultDone(x) => x
case x => fail(s"unexpected Result when enriching value: $x")
}
SubmittedTransaction(consume(enricher.enrichTransaction(tx)))
SubmittedTransaction(consume(enricher.enrichVersionedTransaction(tx)))
}
engine
.submit(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -215,8 +215,9 @@ class ValueEnricherSpec extends AnyWordSpec with Matchers with TableDrivenProper
outputRecord,
)

enricher.enrichTransaction(inputTransaction) shouldNot be(ResultDone(inputTransaction))
enricher.enrichTransaction(inputTransaction) shouldBe ResultDone(outputTransaction)
enricher.enrichVersionedTransaction(inputTransaction) shouldNot
be(ResultDone(inputTransaction))
enricher.enrichVersionedTransaction(inputTransaction) shouldBe ResultDone(outputTransaction)

}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,8 @@ private[lf] object PartialTransaction {
type Node = Node.GenNode
type LeafNode = Node.LeafOnlyActionNode

private type TX = GenTransaction
private type ExerciseNode = Node.NodeExercises

private final case class IncompleteTxImpl(
val transaction: TX,
val locationInfo: Map[NodeId, Location],
) extends transaction.IncompleteTransaction

sealed abstract class ContextInfo {
val actionChildSeed: Int => crypto.Hash
def authorizers: Set[Party]
Expand Down Expand Up @@ -395,7 +389,7 @@ private[speedy] case class PartialTransaction(

val ptx = unwind()

IncompleteTxImpl(
transaction.IncompleteTransaction(
GenTransaction(
ptx.nodes,
ptx.context.children.toImmArray.toSeq.sortBy(_.index).toImmArray,
Expand Down
8 changes: 2 additions & 6 deletions daml-lf/scenario-interpreter/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,7 @@ da_scala_test_suite(
name = "scenario-interpreter_tests",
size = "small",
srcs = glob(["src/test/**/*.scala"]),
scala_deps = [
"@maven//:org_scalaz_scalaz_core",
],
scala_deps = ["@maven//:org_scalaz_scalaz_core"],
scalacopts = lf_scalacopts,
deps = [
":scenario-interpreter",
Expand All @@ -71,9 +69,7 @@ da_scala_benchmark_jmh(
":CollectAuthority.dar",
":CollectAuthority.dar.pp",
],
scala_deps = [
"@maven//:org_scalaz_scalaz_core",
],
scala_deps = ["@maven//:org_scalaz_scalaz_core"],
visibility = ["//visibility:public"],
deps = [
"//bazel_tools/runfiles:scala_runfiles",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,44 @@ object ScenarioRunner {
}
}

private[this] abstract class Enricher {
def enrich(tx: SubmittedTransaction): SubmittedTransaction
def enrich(tx: IncompleteTransaction): IncompleteTransaction
}

private[this] object NoEnricher extends Enricher {
override def enrich(tx: SubmittedTransaction): SubmittedTransaction = tx
override def enrich(tx: IncompleteTransaction): IncompleteTransaction = tx
}

private[this] class EnricherImpl(compiledPackages: CompiledPackages) extends Enricher {
val config = Engine.DevEngine().config
val valueTranslator =
new ValueTranslator(
interface = compiledPackages.interface,
forbidV0ContractId = config.forbidV0ContractId,
requireV1ContractIdSuffix = config.requireSuffixedGlobalContractId,
)
def translateValue(typ: Ast.Type, value: Value): Result[SValue] =
valueTranslator.translateValue(typ, value) match {
case Left(err) => ResultError(err)
case Right(sv) => ResultDone(sv)
}
def loadPackage(pkgId: PackageId, context: language.Reference): Result[Unit] = {
crash(LookupError.MissingPackage.pretty(pkgId, context))
}
val enricher = new ValueEnricher(compiledPackages, translateValue, loadPackage)
def consume[V](res: Result[V]): V =
res match {
case ResultDone(x) => x
case x => crash(s"unexpected Result when enriching value: $x")
}
override def enrich(tx: SubmittedTransaction): SubmittedTransaction =
SubmittedTransaction(consume(enricher.enrichVersionedTransaction(tx)))
override def enrich(tx: IncompleteTransaction): IncompleteTransaction =
consume(enricher.enrichIncompleteTransaction(tx))
}

def submit[R](
compiledPackages: CompiledPackages,
ledger: LedgerApi[R],
Expand All @@ -398,53 +436,29 @@ object ScenarioRunner {
commitLocation = location,
)
val onLedger = ledgerMachine.withOnLedger(NameOf.qualifiedNameOfCurrentFunc)(identity)

def enrich(tx: SubmittedTransaction): SubmittedTransaction = {
val config = Engine.DevEngine().config
val valueTranslator =
new ValueTranslator(
interface = compiledPackages.interface,
forbidV0ContractId = config.forbidV0ContractId,
requireV1ContractIdSuffix = config.requireSuffixedGlobalContractId,
)
def translateValue(typ: Ast.Type, value: Value): Result[SValue] =
valueTranslator.translateValue(typ, value) match {
case Left(err) => ResultError(err)
case Right(sv) => ResultDone(sv)
}
def loadPackage(pkgId: PackageId, context: language.Reference): Result[Unit] = {
crash(LookupError.MissingPackage.pretty(pkgId, context))
}
val enricher = new ValueEnricher(compiledPackages, translateValue, loadPackage)
def consume[V](res: Result[V]): V =
res match {
case ResultDone(x) => x
case x => crash(s"unexpected Result when enriching value: $x")
}
SubmittedTransaction(consume(enricher.enrichTransaction(tx)))
}
val enricher = if (doEnrichment) new EnricherImpl(compiledPackages) else NoEnricher
import enricher._

@tailrec
def go(): SubmissionResult[R] = {
ledgerMachine.run() match {
case SResult.SResultFinalValue(resultValue) =>
onLedger.ptxInternal.finish match {
case PartialTransaction.CompleteTransaction(tx0, locationInfo, _) =>
val tx = if (doEnrichment) enrich(tx0) else tx0
ledger.commit(committers, readAs, location, tx, locationInfo) match {
case PartialTransaction.CompleteTransaction(tx, locationInfo, _) =>
ledger.commit(committers, readAs, location, enrich(tx), locationInfo) match {
case Left(err) =>
SubmissionError(err, onLedger.incompleteTransaction)
SubmissionError(err, enrich(onLedger.incompleteTransaction))
case Right(r) =>
Commit(r, resultValue, onLedger.incompleteTransaction)
Commit(r, resultValue, enrich(onLedger.incompleteTransaction))
}
case PartialTransaction.IncompleteTransaction(ptx) =>
throw new RuntimeException(s"Unexpected abort: $ptx")
}
case SResultError(err) =>
SubmissionError(Error.RunnerException(err), onLedger.incompleteTransaction)
SubmissionError(Error.RunnerException(err), enrich(onLedger.incompleteTransaction))
case SResultNeedContract(coid, tid @ _, committers, callback) =>
ledger.lookupContract(coid, committers, readAs, callback) match {
case Left(err) => SubmissionError(err, onLedger.incompleteTransaction)
case Left(err) => SubmissionError(err, enrich(onLedger.incompleteTransaction))
case Right(_) => go()
}
case SResultNeedKey(keyWithMaintainers, committers, callback) =>
Expand All @@ -455,7 +469,7 @@ object ScenarioRunner {
readAs,
callback,
) match {
case Left(err) => SubmissionError(err, onLedger.incompleteTransaction)
case Left(err) => SubmissionError(err, enrich(onLedger.incompleteTransaction))
case Right(_) => go()
}
case SResultNeedTime(callback) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,7 @@ package transaction

import com.daml.lf.data.Ref.Location

trait IncompleteTransaction {

type Nid = NodeId
type TX = GenTransaction
type ExerciseNode = Node.NodeExercises

def transaction: TX

def locationInfo: Map[Nid, Location]
}
final case class IncompleteTransaction(
transaction: GenTransaction,
locationInfo: Map[NodeId, Location],
)
Original file line number Diff line number Diff line change
Expand Up @@ -100,11 +100,12 @@ private[sandbox] final class InMemoryLedger(
)
}

private def enrichTX(tx: LedgerEntry.Transaction): LedgerEntry.Transaction = {
private def enrichTX(tx: LedgerEntry.Transaction): LedgerEntry.Transaction =
tx.copy(transaction =
CommittedTransaction(consumeEnricherResult(enricher.enrichTransaction(tx.transaction)))
CommittedTransaction(
consumeEnricherResult(enricher.enrichVersionedTransaction(tx.transaction))
)
)
}

private val logger = ContextualizedLogger.get(this.getClass)

Expand Down

0 comments on commit c8006b8

Please sign in to comment.