Skip to content

Commit

Permalink
Merge pull request #1419 from informalsystems/enable-records-in-array…
Browse files Browse the repository at this point in the history
…s-encoding

Make records as they are work in the arrays encoding
  • Loading branch information
rodrigo7491 committed Mar 8, 2022
2 parents bbdf123 + 93a5942 commit 6a3dd9f
Show file tree
Hide file tree
Showing 11 changed files with 150 additions and 162 deletions.
6 changes: 3 additions & 3 deletions UNRELEASED.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
DO NOT LEAVE A BLANK LINE BELOW THIS PREAMBLE -->
### Features

* Implement the sequence constructor `Apalache!MkSeq`, see #143
* Enable records in the arrays encoding, see #1288
* Implement the sequence constructor `Apalache!MkSeq`, see #1439
* Add support for `Apalache!FunAsSeq`, see #1442
* Implement `EXCEPT` on sequences, see #1444

### Bug fixes
* Fixed bug where TLA+ `LAMBDA`s wouldn't inline outside `Fold` and `MkSeq`, see #1446
* Fixed bug where TLA+ `LAMBDA`s wouldn't inline outside `Fold` and `MkSeq`, see #1446
Original file line number Diff line number Diff line change
Expand Up @@ -239,26 +239,40 @@ class LazyEquality(rewriter: SymbStateRewriter)
}

private def mkSetEq(state: SymbState, left: ArenaCell, right: ArenaCell): SymbState = {
if (left.cellType == FinSetT(UnknownT()) && state.arena.getHas(left).isEmpty) {
// The statically empty set is a very special case, as its element type is unknown.
// Hence, we cannot use SMT equality, as it does not work with different sorts.
mkEmptySetEq(state, left, right)
} else if (right.cellType == FinSetT(UnknownT()) && state.arena.getHas(right).isEmpty) {
mkEmptySetEq(state, right, left) // same here
} else {
// in general, we need 2 * |X| * |Y| comparisons
val leftToRight: SymbState = subsetEq(state, left, right)
val rightToLeft: SymbState = subsetEq(leftToRight, right, left)
// the type checker makes sure that this holds true
assert(left.cellType.signature == right.cellType.signature)
// These two sets have the same signature and thus belong to the same sort.
// Hence, we can use SMT equality. This equality is needed by uninterpreted functions.
val eq = tla.equiv(tla.eql(left.toNameEx, right.toNameEx), tla.and(leftToRight.ex, rightToLeft.ex))
rewriter.solverContext.assertGroundExpr(eq)
eqCache.put(left, right, EqCache.EqEntry())
rewriter.solverContext.config.smtEncoding match {
case `arraysEncoding` =>
// In the arrays encoding we only cache the equalities between the sets' elements
val leftElems = state.arena.getHas(left)
val rightElems = state.arena.getHas(right)
cacheEqConstraints(state, leftElems.cross(rightElems)) // cache all the equalities
eqCache.put(left, right, EqCache.EqEntry())
state

case `oopsla19Encoding` =>
if (left.cellType == FinSetT(UnknownT()) && state.arena.getHas(left).isEmpty) {
// The statically empty set is a very special case, as its element type is unknown.
// Hence, we cannot use SMT equality, as it does not work with different sorts.
mkEmptySetEq(state, left, right)
} else if (right.cellType == FinSetT(UnknownT()) && state.arena.getHas(right).isEmpty) {
mkEmptySetEq(state, right, left) // same here
} else {
// in general, we need 2 * |X| * |Y| comparisons
val leftToRight: SymbState = subsetEq(state, left, right)
val rightToLeft: SymbState = subsetEq(leftToRight, right, left)
// the type checker makes sure that this holds true
assert(left.cellType.signature == right.cellType.signature)
// These two sets have the same signature and thus belong to the same sort.
// Hence, we can use SMT equality.
val eq = tla.equiv(tla.eql(left.toNameEx, right.toNameEx), tla.and(leftToRight.ex, rightToLeft.ex))
rewriter.solverContext.assertGroundExpr(eq)
eqCache.put(left, right, EqCache.EqEntry())

// recover the original expression
rightToLeft.setRex(state.ex)
}

// recover the original expression and theory
rightToLeft.setRex(state.ex)
case oddEncodingType =>
throw new IllegalArgumentException(s"Unexpected SMT encoding of type $oddEncodingType")
}
}

Expand Down Expand Up @@ -438,13 +452,29 @@ class LazyEquality(rewriter: SymbStateRewriter)
private def mkFunEq(state: SymbState, leftFun: ArenaCell, rightFun: ArenaCell): SymbState = {
val leftRel = state.arena.getCdm(leftFun)
val rightRel = state.arena.getCdm(rightFun)
val relEq = mkSetEq(state, leftRel, rightRel)
rewriter.solverContext.assertGroundExpr(tla.equiv(tla.eql(leftFun.toNameEx, rightFun.toNameEx),
tla.eql(leftRel.toNameEx, rightRel.toNameEx)))
eqCache.put(leftFun, rightFun, EqCache.EqEntry())

// restore the original expression and theory
relEq.setRex(state.ex)
rewriter.solverContext.config.smtEncoding match {
case `arraysEncoding` =>
// In the arrays encoding we only cache the equalities between the elements of the functions' ranges
// This is because the ranges consist of pairs of form <arg,res>, thus the domains are also handled
val leftElems = state.arena.getHas(leftRel)
val rightElems = state.arena.getHas(rightRel)
cacheEqConstraints(state, leftElems.cross(rightElems)) // Cache all the equalities
eqCache.put(leftFun, rightFun, EqCache.EqEntry())
state

case `oopsla19Encoding` =>
val relEq = mkSetEq(state, leftRel, rightRel)
rewriter.solverContext.assertGroundExpr(tla.equiv(tla.eql(leftFun.toNameEx, rightFun.toNameEx),
tla.eql(leftRel.toNameEx, rightRel.toNameEx)))
eqCache.put(leftFun, rightFun, EqCache.EqEntry())

// Restore the original expression and theory
relEq.setRex(state.ex)

case oddEncodingType =>
throw new IllegalArgumentException(s"Unexpected SMT encoding of type $oddEncodingType")
}
}

private def mkRecordEq(state: SymbState, leftRec: ArenaCell, rightRec: ArenaCell): SymbState = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,6 @@ class SymbStateRewriterImplWithArrays(

val newRules: Map[String, List[RewritingRule]] = {
Map(
// logic
key(tla.eql(tla.name("x"), tla.name("y")))
-> List(new EqRuleWithArrays(this)), // TODO: update with additional elements later
// sets
key(tla.in(tla.name("x"), tla.name("S")))
-> List(new SetInRuleWithArrays(this)), // TODO: add support for funSet later
Expand All @@ -62,58 +59,8 @@ class SymbStateRewriterImplWithArrays(

val unsupportedRules: Map[String, List[RewritingRule]] = {
Map(
// logic
key(tla.choose(tla.name("x"), tla.name("S"), tla.name("p")))
-> List(new ChooseRule(this)),
// tuples, records, and sequences
key(tla.head(tla.tuple(tla.name("x"))))
-> List(new SeqOpsRule(this)),
key(tla.tail(tla.tuple(tla.name("x"))))
-> List(new SeqOpsRule(this)),
key(tla.subseq(tla.tuple(tla.name("x")), tla.int(2), tla.int(4)))
-> List(new SeqOpsRule(this)),
key(tla.len(tla.tuple(tla.name("x"))))
-> List(new SeqOpsRule(this)),
key(tla.append(tla.tuple(tla.name("x")), tla.int(10)))
-> List(new SeqOpsRule(this)),
key(tla.concat(tla.name("Seq1"), tla.name("Seq2")))
-> List(new SeqOpsRule(this)),
key(OperEx(ApalacheOper.gen, tla.int(2)))
-> List(new GenRule(this)),
// folds
key(OperEx(ApalacheOper.foldSet, tla.name("A"), tla.name("v"), tla.name("S")))
-> List(new FoldSetRule(this)),
key(OperEx(ApalacheOper.foldSeq, tla.name("A"), tla.name("v"), tla.name("s")))
-> List(new FoldSeqRule(this)),
// -----------------------------------------------------------------------
// RULES BELOW WERE NOT REMOVED TO RUN TESTS, WILL BE LOOKED AT LATER
// -----------------------------------------------------------------------
/*
// logic
key(OperEx(ApalacheOper.skolem, tla.exists(tla.name("x"), tla.name("S"), tla.name("p"))))
-> List(new QuantRule(this)),
key(tla.exists(tla.name("x"), tla.name("S"), tla.name("p")))
-> List(new QuantRule(this)),
key(tla.forall(tla.name("x"), tla.name("S"), tla.name("p")))
-> List(new QuantRule(this)),
// control flow
key(tla.ite(tla.name("cond"), tla.name("then"), tla.name("else")))
-> List(new IfThenElseRule(this)),
key(tla.letIn(tla.int(1), TlaOperDecl("A", List(), tla.int(2))))
-> List(new LetInRule(this)),
key(tla.appDecl(TlaOperDecl("userOp", List(), tla.int(3)))) ->
List(new UserOperRule(this)),
// functions
key(tla.recFunDef(tla.name("e"), tla.name("x"), tla.name("S")))
-> List(new RecFunDefAndRefRule(this)),
key(tla.recFunRef())
-> List(new RecFunDefAndRefRule(this)),
// tuples, records, and sequences
key(tla.tuple(tla.name("x"), tla.int(2)))
-> List(new TupleOrSeqCtorRule(this)),
key(tla.enumFun(tla.str("a"), tla.int(2)))
-> List(new RecCtorRule(this))
*/
-> List(new GenRule(this))
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ class DomainRule(rewriter: SymbStateRewriter, intRangeCache: IntRangeCache) exte
val funState = rewriter.rewriteUntilDone(state.setRex(funEx))
val funCell = funState.asCell

// no type information from the type finder is needed, as types are propagated in a straightforward manner
funCell.cellType match {
case RecordT(_) =>
val dom = funState.arena.getDom(funCell)
Expand All @@ -57,13 +56,13 @@ class DomainRule(rewriter: SymbStateRewriter, intRangeCache: IntRangeCache) exte
}
}

private def mkTupleDomain(state: SymbState, funCell: ArenaCell): SymbState = {
protected def mkTupleDomain(state: SymbState, funCell: ArenaCell): SymbState = {
val tupleT = funCell.cellType.asInstanceOf[TupleT]
val (newArena, dom) = intRangeCache.create(state.arena, (1, tupleT.args.size))
state.setArena(newArena).setRex(dom.toNameEx)
}

private def mkSeqDomain(state: SymbState, seqCell: ArenaCell): SymbState = {
protected def mkSeqDomain(state: SymbState, seqCell: ArenaCell): SymbState = {
val (protoSeq, len, capacity) = proto.unpackSeq(state.arena, seqCell)
// We do not know the domain precisely, as it depends on the length of the sequence.
// Hence, we create the set `1..capacity` and include only those elements that are not greater than `len`.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package at.forsyte.apalache.tla.bmcmt.rules

import at.forsyte.apalache.tla.bmcmt.{RewriterException, SymbState, SymbStateRewriter}
import at.forsyte.apalache.tla.bmcmt.caches.IntRangeCache
import at.forsyte.apalache.tla.bmcmt.types.FunT
import at.forsyte.apalache.tla.bmcmt.types.{FunT, RecordT, SeqT, TupleT}
import at.forsyte.apalache.tla.lir.OperEx
import at.forsyte.apalache.tla.lir.oper.TlaFunOper

Expand All @@ -21,14 +21,25 @@ class DomainRuleWithArrays(rewriter: SymbStateRewriter, intRangeCache: IntRangeC
val funState = rewriter.rewriteUntilDone(state.setRex(funEx))
val funCell = funState.asCell

// TODO: consider records, tuples, and sequences in the arrays encoding
funCell.cellType match {
case RecordT(_) =>
val dom = funState.arena.getDom(funCell)
funState.setRex(dom.toNameEx)

case TupleT(_) =>
mkTupleDomain(funState, funCell)

case SeqT(_) =>
mkSeqDomain(funState, funCell)

case FunT(_, _) =>
val dom = funState.arena.getDom(funCell)
funState.setRex(dom.toNameEx)

case _ =>
// TODO: consider records, tuples, and sequences in the arrays encoding
super.apply(state)
throw new RewriterException("DOMAIN x where type(x) = %s is not implemented".format(funCell.cellType),
state.ex)
}

case _ =>
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -23,28 +23,30 @@ class FunAppRuleWithArrays(rewriter: SymbStateRewriter) extends FunAppRule(rewri
var nextState = rewriter.rewriteUntilDone(state.setRex(argEx))
val argCell = nextState.asCell

val domainCell = nextState.arena.getDom(funCell)
val relationCell = nextState.arena.getCdm(funCell)
val relationElems = nextState.arena.getHas(relationCell)
val nElems = relationElems.size

nextState = nextState.updateArena(_.appendCell(elemT, isUnconstrained = true)) // The cell will be unconstrained
val res = nextState.arena.topCell

// If the domain is non-empty we query the array representing the function
// if the domain is empty we return an unconstrained value
if (nElems > 0) {
// The SMT constraints are produced here
val select = tla.apalacheSelectInFun(argCell.toNameEx, funCell.toNameEx)
val eql = tla.eql(res.toNameEx, select)
rewriter.solverContext.assertGroundExpr(eql)
var comparableArgCell = false

// The cell metadata is propagated here
// We compare argCell with relationElems and if an equality is found we simply propagate elemRes
for (elem <- relationElems) {
val elemTuple = nextState.arena.getHas(elem)
assert(elemTuple.size == 2) // elem should always have only edges to <arg,res>
val elemArg = elemTuple(0)
val elemRes = elemTuple(1)
if (elemArg == argCell) {
if (argCell == elemArg) {
comparableArgCell = true
// If argCell is comparable at the Scala level, we generate SMT constraints based on it
val select = tla.apalacheSelectInFun(elemArg.toNameEx, funCell.toNameEx)
val eql = tla.eql(res.toNameEx, select)
rewriter.solverContext.assertGroundExpr(eql)

nextState = nextState.updateArena(_.appendHasNoSmt(res, nextState.arena.getHas(elemRes): _*))

if (elemRes.cellType.isInstanceOf[FunT] || elemRes.cellType.isInstanceOf[FinFunSetT]) {
Expand All @@ -56,6 +58,54 @@ class FunAppRuleWithArrays(rewriter: SymbStateRewriter) extends FunAppRule(rewri
}
}
}

// If argCell is not comparable at the Scala level, e.g., due to quantifier use, we need to use an oracle
if (!comparableArgCell) {
// We use an oracle to pick an arg for which the function is applied
val (oracleState, oracle) = picker.oracleFactory.newDefaultOracle(nextState, nElems + 1)
nextState = picker.pickByOracle(oracleState, oracle, relationElems, oracleState.arena.cellTrue().toNameEx)
val pickedElem = nextState.asCell

assert(nextState.arena.getHas(pickedElem).size == 2)
val pickedArg = nextState.arena.getHas(pickedElem)(0)
val pickedRes = nextState.arena.getHas(pickedElem)(1)

// Cache the equality between the picked argument and the supplied argument, O(1) constraints
nextState = rewriter.lazyEq.cacheEqConstraints(nextState, Seq((pickedArg, argCell)))
// If oracle < N, then pickedArg = argCell
val pickedElemInDom = tla.not(oracle.whenEqualTo(nextState, nElems))
val argEq = tla.eql(pickedArg.toNameEx, argCell.toNameEx)
val argEqWhenNonEmpty = tla.impl(pickedElemInDom, argEq)
rewriter.solverContext.assertGroundExpr(argEqWhenNonEmpty)

// We require oracle = N iff there is no element equal to argCell, O(N) constraints
for ((elem, no) <- relationElems.zipWithIndex) {
val elemArg = nextState.arena.getHas(elem).head
nextState = rewriter.lazyEq.cacheEqConstraints(nextState, Seq((elemArg, argCell)))
val inDom = tla.apalacheSelectInSet(elemArg.toNameEx, domainCell.toNameEx)
val neqArg = tla.not(rewriter.lazyEq.safeEq(elemArg, argCell))
val found = tla.not(oracle.whenEqualTo(nextState, nElems))
// ~inDom \/ neqArg \/ found, or equivalently, (inDom /\ elemArg = argCell) => found
rewriter.solverContext.assertGroundExpr(tla.or(tla.not(inDom), neqArg, found))
// oracle = i => inDom. The equality pickedArg = argCell is enforced by argEqWhenNonEmpty
rewriter.solverContext.assertGroundExpr(tla.or(tla.not(oracle.whenEqualTo(nextState, no)), inDom))
}

// We simply apply the picked arg to the SMT array encoding the function, O(1) constraints
val select = tla.apalacheSelectInFun(argCell.toNameEx, funCell.toNameEx)
val eql = tla.eql(res.toNameEx, select)
rewriter.solverContext.assertGroundExpr(eql)

// The edges, dom, and cdm are forwarded below
nextState = nextState.updateArena(_.appendHasNoSmt(res, nextState.arena.getHas(pickedRes): _*))
if (pickedRes.cellType.isInstanceOf[FunT] || pickedRes.cellType.isInstanceOf[FinFunSetT]) {
nextState = nextState.updateArena(_.setDom(res, nextState.arena.getDom(pickedRes)))
nextState = nextState.updateArena(_.setCdm(res, nextState.arena.getCdm(pickedRes)))
} else if (pickedRes.cellType.isInstanceOf[RecordT]) {
// Records do not contain cdm metadata
nextState = nextState.updateArena(_.setDom(res, nextState.arena.getDom(pickedRes)))
}
}
}

nextState.setRex(res.toNameEx)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ class SetInRuleWithArrays(rewriter: SymbStateRewriter) extends SetInRule(rewrite
funsetCdm.cellType match {
case _: PowSetT =>
nextState = rewriter.rewriteUntilDone(nextState.setRex(tla.appFun(funCell.toNameEx, funsetDomElem.toNameEx)))
val funAppRes = nextState.arena.topCell
val funAppRes = nextState.asCell
val powSetDom = nextState.arena.getDom(funsetCdm)
val subsetEq = tla.subseteq(funAppRes.toNameEx, powSetDom.toNameEx)
nextState = rewriter.rewriteUntilDone(nextState.setRex(subsetEq))
Expand Down
Loading

0 comments on commit 6a3dd9f

Please sign in to comment.