Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make records as they are work in the arrays encoding #1419

Merged
merged 16 commits into from
Mar 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions UNRELEASED.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@

DO NOT LEAVE A BLANK LINE BELOW THIS PREAMBLE -->
### Features

* Implement the sequence constructor `Apalache!MkSeq`, see #143
* Enable records in the arrays encoding, see #1288
* Implement the sequence constructor `Apalache!MkSeq`, see #1439
* Add support for `Apalache!FunAsSeq`, see #1442
* Implement `EXCEPT` on sequences, see #1444

### Bug fixes
* Fixed bug where TLA+ `LAMBDA`s wouldn't inline outside `Fold` and `MkSeq`, see #1446
* Fixed bug where TLA+ `LAMBDA`s wouldn't inline outside `Fold` and `MkSeq`, see #1446
Original file line number Diff line number Diff line change
Expand Up @@ -239,26 +239,40 @@ class LazyEquality(rewriter: SymbStateRewriter)
}

private def mkSetEq(state: SymbState, left: ArenaCell, right: ArenaCell): SymbState = {
if (left.cellType == FinSetT(UnknownT()) && state.arena.getHas(left).isEmpty) {
// The statically empty set is a very special case, as its element type is unknown.
// Hence, we cannot use SMT equality, as it does not work with different sorts.
mkEmptySetEq(state, left, right)
} else if (right.cellType == FinSetT(UnknownT()) && state.arena.getHas(right).isEmpty) {
mkEmptySetEq(state, right, left) // same here
} else {
// in general, we need 2 * |X| * |Y| comparisons
val leftToRight: SymbState = subsetEq(state, left, right)
val rightToLeft: SymbState = subsetEq(leftToRight, right, left)
// the type checker makes sure that this holds true
assert(left.cellType.signature == right.cellType.signature)
// These two sets have the same signature and thus belong to the same sort.
// Hence, we can use SMT equality. This equality is needed by uninterpreted functions.
val eq = tla.equiv(tla.eql(left.toNameEx, right.toNameEx), tla.and(leftToRight.ex, rightToLeft.ex))
rewriter.solverContext.assertGroundExpr(eq)
eqCache.put(left, right, EqCache.EqEntry())
rewriter.solverContext.config.smtEncoding match {
case `arraysEncoding` =>
// In the arrays encoding we only cache the equalities between the sets' elements
val leftElems = state.arena.getHas(left)
val rightElems = state.arena.getHas(right)
cacheEqConstraints(state, leftElems.cross(rightElems)) // cache all the equalities
eqCache.put(left, right, EqCache.EqEntry())
state

case `oopsla19Encoding` =>
if (left.cellType == FinSetT(UnknownT()) && state.arena.getHas(left).isEmpty) {
// The statically empty set is a very special case, as its element type is unknown.
// Hence, we cannot use SMT equality, as it does not work with different sorts.
mkEmptySetEq(state, left, right)
} else if (right.cellType == FinSetT(UnknownT()) && state.arena.getHas(right).isEmpty) {
mkEmptySetEq(state, right, left) // same here
} else {
Comment on lines +252 to +258
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a very old piece of code. Let's replace it with an assertion that both left.cellType and right.cellType are FinSetT(_).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See #1430.

// in general, we need 2 * |X| * |Y| comparisons
val leftToRight: SymbState = subsetEq(state, left, right)
val rightToLeft: SymbState = subsetEq(leftToRight, right, left)
// the type checker makes sure that this holds true
assert(left.cellType.signature == right.cellType.signature)
// These two sets have the same signature and thus belong to the same sort.
// Hence, we can use SMT equality.
val eq = tla.equiv(tla.eql(left.toNameEx, right.toNameEx), tla.and(leftToRight.ex, rightToLeft.ex))
rewriter.solverContext.assertGroundExpr(eq)
eqCache.put(left, right, EqCache.EqEntry())

// recover the original expression
rightToLeft.setRex(state.ex)
}

// recover the original expression and theory
rightToLeft.setRex(state.ex)
case oddEncodingType =>
throw new IllegalArgumentException(s"Unexpected SMT encoding of type $oddEncodingType")
}
}

Expand Down Expand Up @@ -438,13 +452,29 @@ class LazyEquality(rewriter: SymbStateRewriter)
private def mkFunEq(state: SymbState, leftFun: ArenaCell, rightFun: ArenaCell): SymbState = {
val leftRel = state.arena.getCdm(leftFun)
val rightRel = state.arena.getCdm(rightFun)
val relEq = mkSetEq(state, leftRel, rightRel)
rewriter.solverContext.assertGroundExpr(tla.equiv(tla.eql(leftFun.toNameEx, rightFun.toNameEx),
tla.eql(leftRel.toNameEx, rightRel.toNameEx)))
eqCache.put(leftFun, rightFun, EqCache.EqEntry())

// restore the original expression and theory
relEq.setRex(state.ex)
rewriter.solverContext.config.smtEncoding match {
case `arraysEncoding` =>
// In the arrays encoding we only cache the equalities between the elements of the functions' ranges
rodrigo7491 marked this conversation as resolved.
Show resolved Hide resolved
// This is because the ranges consist of pairs of form <arg,res>, thus the domains are also handled
val leftElems = state.arena.getHas(leftRel)
val rightElems = state.arena.getHas(rightRel)
cacheEqConstraints(state, leftElems.cross(rightElems)) // Cache all the equalities
eqCache.put(leftFun, rightFun, EqCache.EqEntry())
state

case `oopsla19Encoding` =>
val relEq = mkSetEq(state, leftRel, rightRel)
rewriter.solverContext.assertGroundExpr(tla.equiv(tla.eql(leftFun.toNameEx, rightFun.toNameEx),
tla.eql(leftRel.toNameEx, rightRel.toNameEx)))
eqCache.put(leftFun, rightFun, EqCache.EqEntry())

// Restore the original expression and theory
relEq.setRex(state.ex)

case oddEncodingType =>
throw new IllegalArgumentException(s"Unexpected SMT encoding of type $oddEncodingType")
}
}

private def mkRecordEq(state: SymbState, leftRec: ArenaCell, rightRec: ArenaCell): SymbState = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,6 @@ class SymbStateRewriterImplWithArrays(

val newRules: Map[String, List[RewritingRule]] = {
Map(
// logic
key(tla.eql(tla.name("x"), tla.name("y")))
-> List(new EqRuleWithArrays(this)), // TODO: update with additional elements later
// sets
key(tla.in(tla.name("x"), tla.name("S")))
-> List(new SetInRuleWithArrays(this)), // TODO: add support for funSet later
Expand All @@ -62,58 +59,8 @@ class SymbStateRewriterImplWithArrays(

val unsupportedRules: Map[String, List[RewritingRule]] = {
Map(
// logic
key(tla.choose(tla.name("x"), tla.name("S"), tla.name("p")))
-> List(new ChooseRule(this)),
// tuples, records, and sequences
key(tla.head(tla.tuple(tla.name("x"))))
-> List(new SeqOpsRule(this)),
key(tla.tail(tla.tuple(tla.name("x"))))
-> List(new SeqOpsRule(this)),
key(tla.subseq(tla.tuple(tla.name("x")), tla.int(2), tla.int(4)))
-> List(new SeqOpsRule(this)),
key(tla.len(tla.tuple(tla.name("x"))))
-> List(new SeqOpsRule(this)),
key(tla.append(tla.tuple(tla.name("x")), tla.int(10)))
-> List(new SeqOpsRule(this)),
key(tla.concat(tla.name("Seq1"), tla.name("Seq2")))
-> List(new SeqOpsRule(this)),
key(OperEx(ApalacheOper.gen, tla.int(2)))
-> List(new GenRule(this)),
// folds
key(OperEx(ApalacheOper.foldSet, tla.name("A"), tla.name("v"), tla.name("S")))
-> List(new FoldSetRule(this)),
key(OperEx(ApalacheOper.foldSeq, tla.name("A"), tla.name("v"), tla.name("s")))
-> List(new FoldSeqRule(this)),
// -----------------------------------------------------------------------
// RULES BELOW WERE NOT REMOVED TO RUN TESTS, WILL BE LOOKED AT LATER
// -----------------------------------------------------------------------
/*
// logic
key(OperEx(ApalacheOper.skolem, tla.exists(tla.name("x"), tla.name("S"), tla.name("p"))))
-> List(new QuantRule(this)),
key(tla.exists(tla.name("x"), tla.name("S"), tla.name("p")))
-> List(new QuantRule(this)),
key(tla.forall(tla.name("x"), tla.name("S"), tla.name("p")))
-> List(new QuantRule(this)),
// control flow
key(tla.ite(tla.name("cond"), tla.name("then"), tla.name("else")))
-> List(new IfThenElseRule(this)),
key(tla.letIn(tla.int(1), TlaOperDecl("A", List(), tla.int(2))))
-> List(new LetInRule(this)),
key(tla.appDecl(TlaOperDecl("userOp", List(), tla.int(3)))) ->
List(new UserOperRule(this)),
// functions
key(tla.recFunDef(tla.name("e"), tla.name("x"), tla.name("S")))
-> List(new RecFunDefAndRefRule(this)),
key(tla.recFunRef())
-> List(new RecFunDefAndRefRule(this)),
// tuples, records, and sequences
key(tla.tuple(tla.name("x"), tla.int(2)))
-> List(new TupleOrSeqCtorRule(this)),
key(tla.enumFun(tla.str("a"), tla.int(2)))
-> List(new RecCtorRule(this))
*/
-> List(new GenRule(this))
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ class DomainRule(rewriter: SymbStateRewriter, intRangeCache: IntRangeCache) exte
val funState = rewriter.rewriteUntilDone(state.setRex(funEx))
val funCell = funState.asCell

// no type information from the type finder is needed, as types are propagated in a straightforward manner
funCell.cellType match {
case RecordT(_) =>
val dom = funState.arena.getDom(funCell)
Expand All @@ -57,13 +56,13 @@ class DomainRule(rewriter: SymbStateRewriter, intRangeCache: IntRangeCache) exte
}
}

private def mkTupleDomain(state: SymbState, funCell: ArenaCell): SymbState = {
protected def mkTupleDomain(state: SymbState, funCell: ArenaCell): SymbState = {
val tupleT = funCell.cellType.asInstanceOf[TupleT]
val (newArena, dom) = intRangeCache.create(state.arena, (1, tupleT.args.size))
state.setArena(newArena).setRex(dom.toNameEx)
}

private def mkSeqDomain(state: SymbState, seqCell: ArenaCell): SymbState = {
protected def mkSeqDomain(state: SymbState, seqCell: ArenaCell): SymbState = {
val (protoSeq, len, capacity) = proto.unpackSeq(state.arena, seqCell)
// We do not know the domain precisely, as it depends on the length of the sequence.
// Hence, we create the set `1..capacity` and include only those elements that are not greater than `len`.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package at.forsyte.apalache.tla.bmcmt.rules

import at.forsyte.apalache.tla.bmcmt.{RewriterException, SymbState, SymbStateRewriter}
import at.forsyte.apalache.tla.bmcmt.caches.IntRangeCache
import at.forsyte.apalache.tla.bmcmt.types.FunT
import at.forsyte.apalache.tla.bmcmt.types.{FunT, RecordT, SeqT, TupleT}
import at.forsyte.apalache.tla.lir.OperEx
import at.forsyte.apalache.tla.lir.oper.TlaFunOper

Expand All @@ -21,14 +21,25 @@ class DomainRuleWithArrays(rewriter: SymbStateRewriter, intRangeCache: IntRangeC
val funState = rewriter.rewriteUntilDone(state.setRex(funEx))
val funCell = funState.asCell

// TODO: consider records, tuples, and sequences in the arrays encoding
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this TODO obsolete?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not really. It is there as a reminder of what was not touched for when we decide to encode these features differently.

funCell.cellType match {
case RecordT(_) =>
val dom = funState.arena.getDom(funCell)
funState.setRex(dom.toNameEx)

case TupleT(_) =>
mkTupleDomain(funState, funCell)

case SeqT(_) =>
mkSeqDomain(funState, funCell)

case FunT(_, _) =>
val dom = funState.arena.getDom(funCell)
funState.setRex(dom.toNameEx)

case _ =>
// TODO: consider records, tuples, and sequences in the arrays encoding
super.apply(state)
throw new RewriterException("DOMAIN x where type(x) = %s is not implemented".format(funCell.cellType),
state.ex)
}

case _ =>
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -23,28 +23,30 @@ class FunAppRuleWithArrays(rewriter: SymbStateRewriter) extends FunAppRule(rewri
var nextState = rewriter.rewriteUntilDone(state.setRex(argEx))
val argCell = nextState.asCell

val domainCell = nextState.arena.getDom(funCell)
val relationCell = nextState.arena.getCdm(funCell)
val relationElems = nextState.arena.getHas(relationCell)
val nElems = relationElems.size

nextState = nextState.updateArena(_.appendCell(elemT, isUnconstrained = true)) // The cell will be unconstrained
val res = nextState.arena.topCell

// If the domain is non-empty we query the array representing the function
// if the domain is empty we return an unconstrained value
if (nElems > 0) {
// The SMT constraints are produced here
val select = tla.apalacheSelectInFun(argCell.toNameEx, funCell.toNameEx)
val eql = tla.eql(res.toNameEx, select)
rewriter.solverContext.assertGroundExpr(eql)
var comparableArgCell = false

// The cell metadata is propagated here
// We compare argCell with relationElems and if an equality is found we simply propagate elemRes
for (elem <- relationElems) {
val elemTuple = nextState.arena.getHas(elem)
assert(elemTuple.size == 2) // elem should always have only edges to <arg,res>
val elemArg = elemTuple(0)
val elemRes = elemTuple(1)
if (elemArg == argCell) {
if (argCell == elemArg) {
comparableArgCell = true
// If argCell is comparable at the Scala level, we generate SMT constraints based on it
val select = tla.apalacheSelectInFun(elemArg.toNameEx, funCell.toNameEx)
val eql = tla.eql(res.toNameEx, select)
rewriter.solverContext.assertGroundExpr(eql)

nextState = nextState.updateArena(_.appendHasNoSmt(res, nextState.arena.getHas(elemRes): _*))

if (elemRes.cellType.isInstanceOf[FunT] || elemRes.cellType.isInstanceOf[FinFunSetT]) {
Expand All @@ -56,6 +58,54 @@ class FunAppRuleWithArrays(rewriter: SymbStateRewriter) extends FunAppRule(rewri
}
}
}

// If argCell is not comparable at the Scala level, e.g., due to quantifier use, we need to use an oracle
if (!comparableArgCell) {
// We use an oracle to pick an arg for which the function is applied
val (oracleState, oracle) = picker.oracleFactory.newDefaultOracle(nextState, nElems + 1)
nextState = picker.pickByOracle(oracleState, oracle, relationElems, oracleState.arena.cellTrue().toNameEx)
val pickedElem = nextState.asCell

assert(nextState.arena.getHas(pickedElem).size == 2)
val pickedArg = nextState.arena.getHas(pickedElem)(0)
val pickedRes = nextState.arena.getHas(pickedElem)(1)

// Cache the equality between the picked argument and the supplied argument, O(1) constraints
nextState = rewriter.lazyEq.cacheEqConstraints(nextState, Seq((pickedArg, argCell)))
// If oracle < N, then pickedArg = argCell
val pickedElemInDom = tla.not(oracle.whenEqualTo(nextState, nElems))
val argEq = tla.eql(pickedArg.toNameEx, argCell.toNameEx)
val argEqWhenNonEmpty = tla.impl(pickedElemInDom, argEq)
rewriter.solverContext.assertGroundExpr(argEqWhenNonEmpty)

// We require oracle = N iff there is no element equal to argCell, O(N) constraints
for ((elem, no) <- relationElems.zipWithIndex) {
val elemArg = nextState.arena.getHas(elem).head
nextState = rewriter.lazyEq.cacheEqConstraints(nextState, Seq((elemArg, argCell)))
val inDom = tla.apalacheSelectInSet(elemArg.toNameEx, domainCell.toNameEx)
val neqArg = tla.not(rewriter.lazyEq.safeEq(elemArg, argCell))
val found = tla.not(oracle.whenEqualTo(nextState, nElems))
// ~inDom \/ neqArg \/ found, or equivalently, (inDom /\ elemArg = argCell) => found
rewriter.solverContext.assertGroundExpr(tla.or(tla.not(inDom), neqArg, found))
// oracle = i => inDom. The equality pickedArg = argCell is enforced by argEqWhenNonEmpty
rewriter.solverContext.assertGroundExpr(tla.or(tla.not(oracle.whenEqualTo(nextState, no)), inDom))
}

// We simply apply the picked arg to the SMT array encoding the function, O(1) constraints
val select = tla.apalacheSelectInFun(argCell.toNameEx, funCell.toNameEx)
val eql = tla.eql(res.toNameEx, select)
rewriter.solverContext.assertGroundExpr(eql)

// The edges, dom, and cdm are forwarded below
nextState = nextState.updateArena(_.appendHasNoSmt(res, nextState.arena.getHas(pickedRes): _*))
if (pickedRes.cellType.isInstanceOf[FunT] || pickedRes.cellType.isInstanceOf[FinFunSetT]) {
nextState = nextState.updateArena(_.setDom(res, nextState.arena.getDom(pickedRes)))
nextState = nextState.updateArena(_.setCdm(res, nextState.arena.getCdm(pickedRes)))
} else if (pickedRes.cellType.isInstanceOf[RecordT]) {
// Records do not contain cdm metadata
nextState = nextState.updateArena(_.setDom(res, nextState.arena.getDom(pickedRes)))
}
}
}

nextState.setRex(res.toNameEx)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ class SetInRuleWithArrays(rewriter: SymbStateRewriter) extends SetInRule(rewrite
funsetCdm.cellType match {
case _: PowSetT =>
nextState = rewriter.rewriteUntilDone(nextState.setRex(tla.appFun(funCell.toNameEx, funsetDomElem.toNameEx)))
val funAppRes = nextState.arena.topCell
val funAppRes = nextState.asCell
val powSetDom = nextState.arena.getDom(funsetCdm)
val subsetEq = tla.subseteq(funAppRes.toNameEx, powSetDom.toNameEx)
nextState = rewriter.rewriteUntilDone(nextState.setRex(subsetEq))
Expand Down
Loading