Skip to content

Commit

Permalink
Introduce the Repeat Apalache operator (#2927)
Browse files Browse the repository at this point in the history
* Added the Repeat operator

* integration tests + rule fix

* Typo caught by @thpani

* fmt fix

* PR comments

* fmt-fix

---------

Co-authored-by: Igor Konnov <igor@konnov.phd>
  • Loading branch information
Kukovec and konnov authored Aug 14, 2024
1 parent 8206471 commit b54d097
Show file tree
Hide file tree
Showing 20 changed files with 349 additions and 4 deletions.
37 changes: 37 additions & 0 deletions docs/src/lang/apalache-operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,44 @@ IN
The operators `ApaFoldSet` and `ApaFoldSeqLeft` are explained in more detail in a dedicated section [here](../apalache/principles/folds.md).

----------------------------------------------------------------------------
<a name="Repeat"></a>
## Operator iteration

**Notation:** `Repeat(Op, N, x)`

**LaTeX notation:** `Repeat(Op, N, x)`

**Arguments:** Three arguments: An operator `Op`, an iteration counter `N` (a nonnegative constant integer expression), and an
initial value `x`.

**Apalache type:** `((a, Int), Int, a) => a`, for some type `a`.

**Effect:** For a given constant bound `N`, computes the value
`F(F(F(F(x,1), 2), ...), N)`. If `N=0` it evaluates to `x`.

```tla
Repeat(Op, N, x) ==
ApaFoldSeqLeft(Op, x, MkSeq(N, LAMBDA i:i))
```

Apalache implements a more efficient encoding of this operator than the default one.

**Determinism:** Deterministic.

**Errors:**
If any argument is ill-typed, or `N` is not a nonnegative constant integer expression, Apalache reports an error.

**Example in TLA+:**

```tla
Op(a) == a + 1
LET OpModified(a,i) == Op(i)
IN Repeat(OpModified, 0, 5) = 5 \* TRUE
Op2(a,i) == a + i
Repeat(Op2, 0, 5) = 15 \* TRUE
```
----------------------------------------------------------------------------
<a name="SetAsFun"></a>

## Convert a set of pairs to a function
Expand Down
13 changes: 13 additions & 0 deletions src/tla/Apalache.tla
Original file line number Diff line number Diff line change
Expand Up @@ -152,4 +152,17 @@ ApaFoldSeqLeft(__Op(_,_), __v, __seq) ==
THEN __v
ELSE ApaFoldSeqLeft(__Op, __Op(__v, Head(__seq)), Tail(__seq))

(**
* The repetition operator, used to consecutively apply an operator, starting from
* an initial value.
*
* @type: ((a, Int) => a, Int, a) => a;
*)
RECURSIVE Repeat(_,_,_)
Repeat(__F(_,_), __N, __x) ==
\* This is the TLC implementation. Apalache does it differently.
IF __N <= 0
THEN __x
ELSE __F(Repeat(__F, __N - 1, __x), __N)

===============================================================================
66 changes: 66 additions & 0 deletions test/tla/Repeat.tla
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
------------------------ MODULE Repeat ------------------------------------
EXTENDS Apalache, Integers

Inv1 ==
LET Op(a, i) == a + 1
IN Repeat(Op, 5, 0) = 5

Inv2 ==
LET Op(a, i) == a + i
IN Repeat(Op, 5, 0) = 15

\* Cyclical Op: \E k: Op^k = Id
Inv3 ==
LET Op(a,i) == (a + i) % 3
IN LET
v == 1
x0 == Repeat(Op, 0, v)
x3 == Repeat(Op, 3, v)
x6 == Repeat(Op, 6, v)
IN
/\ v = x0
/\ x0 = x3
/\ x3 = x6

\* Idempotent Op: Op^2 = Op
Inv4 ==
LET
\* @type: (Set(Set(Int)), Int) => Set(Set(Int));
Op(a, i) == {UNION a}
IN LET
v == {{1}, {2}, {3,4}}
x1 == Repeat(Op, 1, v)
x2 == Repeat(Op, 2, v)
x3 == Repeat(Op, 3, v)
IN
/\ v /= x1
/\ x1 = x2
/\ x2 = x3

\* Nilpotent Op: \E k: Op^k = 0
Inv5 ==
LET
\* @type: (Set(Int), Int) => Set(Int);
Op(a, i) == a \ { x \in a: \A y \in a: x <= y }
IN LET
v == {1,2,3}
x1 == Repeat(Op, 2, v)
x2 == Repeat(Op, 3, v)
x3 == Repeat(Op, 4, v)
IN
/\ x1 /= x2
/\ x2 = x3
/\ x3 = {}


Init == TRUE
Next == TRUE

Inv ==
/\ Inv1
/\ Inv2
/\ Inv3
/\ Inv4
/\ Inv5

===============================================================================
11 changes: 11 additions & 0 deletions test/tla/RepeatBad.tla
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
------------------------ MODULE RepeatBad ------------------------------------
EXTENDS Apalache, Integers

Inv ==
LET Op(a, i) == a + 1
\* The 2nd argument to Repeat must be an integer literal
IN \E x \in {5} : Repeat(Op, x, 0) = 5

Init == TRUE
Next == TRUE
===============================================================================
18 changes: 18 additions & 0 deletions test/tla/cli-integration-tests.md
Original file line number Diff line number Diff line change
Expand Up @@ -894,6 +894,24 @@ The outcome is: NoError
EXITCODE: OK
```
### check Repeat succeeds
```sh
$ apalache-mc check --inv=Inv --length=0 Repeat.tla | sed 's/I@.*//'
...
The outcome is: NoError
...
EXITCODE: OK
```
### check RepeatBad fails
```sh
$ apalache-mc check --inv=Inv --length=0 RepeatBad.tla | sed 's/I@.*//'
...
EXITCODE: ERROR (255)
```
### check Counter.tla errors (array-encoding)
```sh
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -303,11 +303,13 @@ class SymbStateRewriterImpl(
-> List(new LabelRule(this)),
key(OperEx(ApalacheOper.gen, tla.int(2)))
-> List(new GenRule(this)),
// folds and MkSeq
// folds, repeat and MkSeq
key(OperEx(ApalacheOper.foldSet, tla.name("A"), tla.name("v"), tla.name("S")))
-> List(new FoldSetRule(this, renaming)),
key(OperEx(ApalacheOper.foldSeq, tla.name("A"), tla.name("v"), tla.name("s")))
-> List(new FoldSeqRule(this, renaming)),
key(OperEx(ApalacheOper.repeat, tla.name("Op"), tla.int(10), tla.name("s")))
-> List(new RepeatRule(this, renaming)),
key(OperEx(ApalacheOper.mkSeq, tla.int(10), tla.name("A")))
-> List(new MkSeqRule(this, renaming)),
)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package at.forsyte.apalache.tla.bmcmt.rules

import at.forsyte.apalache.tla.bmcmt._
import at.forsyte.apalache.tla.lir._
import at.forsyte.apalache.tla.lir.oper.ApalacheOper
import at.forsyte.apalache.tla.lir.transformations.impl.IdleTracker
import at.forsyte.apalache.tla.lir.transformations.standard.IncrementalRenaming
import at.forsyte.apalache.tla.lir.values.TlaInt
import at.forsyte.apalache.tla.pp.Inliner
import at.forsyte.apalache.tla.types.tla

/**
* Rewriting rule for Repeat. This rule is similar to [[FoldSeqRule]].
*
* This rule is more efficient than the fold- rules, because it does not require an underlying data structure (Seq or
* Set). In particular, folding over 1..N, despite the overapproximation being tight by construction, still introduces
* O(N*N) constraints, since the SMT solver must assert element uniqueness (i /= j for all i,j \in 1..N). OTOH,
* RepeatRule consumes 1..N in the canonical order natively as a 1.to(N) in Scala.
*
* @author
* Jure Kukovec
*/
class RepeatRule(rewriter: SymbStateRewriter, renaming: IncrementalRenaming) extends RewritingRule {

override def isApplicable(symbState: SymbState): Boolean = {
symbState.ex match {
case OperEx(ApalacheOper.repeat, LetInEx(NameEx(appName), TlaOperDecl(operName, params, _)), _, _) =>
appName == operName && params.size == 2
case _ => false
}
}

override def apply(state: SymbState): SymbState = state.ex match {
// assume isApplicable
case ex @ OperEx(ApalacheOper.repeat, LetInEx(NameEx(_), opDecl), boundEx, baseEx) =>
boundEx match {
case ValEx(TlaInt(n)) if n.isValidInt && n.toInt >= 0 =>
// rewrite baseEx to its final cell form
val baseState = rewriter.rewriteUntilDone(state.setRex(baseEx))

// We need the type signature. The signature of Repeat is
// \A a : ((a,Int) => a, Int, a) => a
// so the operator type must be (a,Int) => a
val a = TlaType1.fromTypeTag(baseEx.typeTag)
val opT = OperT1(Seq(a, IntT1), a)
// sanity check
TlaType1.fromTypeTag(opDecl.typeTag) match {
case `opT` => // all good
case badType =>
val msg = s"FoldSet argument ${opDecl.name} should have the tag $opT, found $badType."
throw new TypingException(msg, opDecl.ID)
}

// expressions are transient, we don't need tracking
val inliner = new Inliner(new IdleTracker, renaming)
// We can make the scope directly, since InlinePass already ensures all is well.
val seededScope: Inliner.Scope = Map(opDecl.name -> opDecl)

// To implement the Repeat rule, we generate a sequence of cells.
// At each step, we perform one application of the operator
// defined by `opDecl` to the previous partial result,
// and pass the iteration index as the second parameter
(1 to n.toInt).foldLeft(baseState) { case (partialState, i) =>
// partialState currently holds the cell representing the previous application step
val oldResultCell = partialState.asCell

// First, we inline the operator application, with cell names as parameters
val appEx = tla.appOp(
tla.name(opDecl.name, opT),
oldResultCell.toBuilder,
tla.int(i),
)

val inlinedEx = inliner.transform(seededScope)(appEx)
rewriter.rewriteUntilDone(partialState.setRex(inlinedEx))
}
case _ =>
throw new RewriterException("Apalache!Repeat expects a constant positive integer. Found: " + boundEx, ex)
}
case _ =>
throw new RewriterException("%s is not applicable".format(getClass.getSimpleName), state.ex)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ object TlaExUtil {
findRec(localBody)
baseExAndCollectionEx.foreach(findRec)

// ignore the names in the auxiliary let-in definition
case OperEx(ApalacheOper.repeat, LetInEx(_, TlaOperDecl(_, _, localBody)), boundAndBaseEx @ _*) =>
findRec(localBody)
boundAndBaseEx.foreach(findRec)

// ignore the names in the auxiliary let-in definition
case OperEx(ApalacheOper.mkSeq, len, LetInEx(_, TlaOperDecl(_, _, localBody))) =>
findRec(localBody)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class TestRewriterWithArrays
with TestSymbStateRewriterPowerset with TestSymbStateRewriterRecord with TestSymbStateRewriterSequence
with TestSymbStateRewriterSet with TestSymbStateRewriterStr with TestSymbStateRewriterTuple
with TestPropositionalOracle with TestSparseOracle with TestUninterpretedConstOracle
with TestSymbStateRewriterApalache with TestSymbStateRewriterMkSeq {
with TestSymbStateRewriterApalache with TestSymbStateRewriterMkSeq with TestSymbStateRewriterRepeat {
override protected def withFixture(test: OneArgTest): Outcome = {
solverContext = new PreproSolverContext(new Z3SolverContext(SolverConfig.default.copy(debug = true,
smtEncoding = SMTEncoding.Arrays)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class TestRewriterWithFunArrays
with TestSymbStateRewriterPowerset with TestSymbStateRewriterRecord with TestSymbStateRewriterSequence
with TestSymbStateRewriterSet with TestSymbStateRewriterStr with TestSymbStateRewriterTuple
with TestPropositionalOracle with TestSparseOracle with TestUninterpretedConstOracle
with TestSymbStateRewriterApalache with TestSymbStateRewriterMkSeq {
with TestSymbStateRewriterApalache with TestSymbStateRewriterMkSeq with TestSymbStateRewriterRepeat {
override protected def withFixture(test: OneArgTest): Outcome = {
solverContext = new PreproSolverContext(new Z3SolverContext(SolverConfig.default.copy(debug = true,
smtEncoding = SMTEncoding.FunArrays)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class TestRewriterWithOOPSLA19
with TestSymbStateRewriterVariant with TestSymbStateRewriterSequence with TestSymbStateRewriterSet
with TestSymbStateRewriterStr with TestSymbStateRewriterTuple with TestPropositionalOracle with TestSparseOracle
with TestUninterpretedConstOracle with TestSymbStateRewriterApalache with TestSymbStateRewriterMkSeq
with TestDefaultValueFactory {
with TestDefaultValueFactory with TestSymbStateRewriterRepeat {
override protected def withFixture(test: OneArgTest): Outcome = {
solverContext = new PreproSolverContext(new Z3SolverContext(SolverConfig.default.copy(debug = true,
smtEncoding = SMTEncoding.OOPSLA19)))
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package at.forsyte.apalache.tla.bmcmt

import at.forsyte.apalache.infra.passes.options.SMTEncoding
import at.forsyte.apalache.tla.lir._
import at.forsyte.apalache.tla.types.tla

trait TestSymbStateRewriterRepeat extends RewriterBase {
test("""Repeat(LET Op(a,i) == a + 1 IN Op, 5, 0) = 5""") { rewriterType: SMTEncoding =>
// Op(x, i) == x + 1
val opT = OperT1(Seq(IntT1, IntT1), IntT1)

val plusOneEx = tla.plus(tla.name("x", IntT1), tla.int(1))
val plusOneOper = tla.decl("Op", plusOneEx, tla.param("x", IntT1), tla.param("i", IntT1))
val letEx = tla.letIn(tla.name(plusOneOper.name, opT), plusOneOper)
val repeatEx = tla.repeat(letEx, 5, tla.int(0))

val rewriter = create(rewriterType)
var state = new SymbState(repeatEx, arena, Binding())
state = rewriter.rewriteUntilDone(state)
val asCell = state.asCell

// compare the value
val eqn = tla.eql(asCell.toBuilder, tla.int(5))
assertTlaExAndRestore(rewriter, state.setRex(eqn))
}

test("""Repeat(LET Op(a,i) == a + i IN Op, 5, 0) = 15""") { rewriterType: SMTEncoding =>
// Op(x, i) == x + i
val opT = OperT1(Seq(IntT1, IntT1), IntT1)

val plusiEx = tla.plus(tla.name("x", IntT1), tla.name("i", IntT1))
val plusiOper = tla.decl("Op", plusiEx, tla.param("x", IntT1), tla.param("i", IntT1))
val letEx = tla.letIn(tla.name(plusiOper.name, opT), plusiOper)
val repeatEx = tla.repeat(letEx, 5, tla.int(0))

val rewriter = create(rewriterType)
var state = new SymbState(repeatEx, arena, Binding())
state = rewriter.rewriteUntilDone(state)
val asCell = state.asCell

// compare the value
val eqn = tla.eql(asCell.toBuilder, tla.int(15))
assertTlaExAndRestore(rewriter, state.setRex(eqn))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ object BuilderCallByName {
ApalacheOper.mkSeq.name -> ApalacheOper.mkSeq,
ApalacheOper.foldSet.name -> ApalacheOper.foldSet,
ApalacheOper.foldSeq.name -> ApalacheOper.foldSeq,
ApalacheOper.repeat.name -> ApalacheOper.repeat,
ApalacheInternalOper.selectInSet.name -> ApalacheInternalOper.selectInSet,
ApalacheInternalOper.selectInFun.name -> ApalacheInternalOper.selectInFun,
ApalacheInternalOper.storeInSet.name -> ApalacheInternalOper.storeInSet,
Expand Down Expand Up @@ -375,6 +376,15 @@ object BuilderCallByName {
case ApalacheOper.foldSeq =>
val Seq(f, v, s) = args
tla.foldSeq(f, v, s)
case ApalacheOper.repeat =>
val Seq(f, n, x) = args
val nEx: TlaEx = n
nEx match {
case ValEx(TlaInt(n)) =>
tla.repeat(f, n, x)
// should never happen, for case-completeness
case _ => throw new JsonDeserializationError(s"${oper.name} requires an integer argument.")
}
case ApalacheInternalOper.selectInSet =>
val Seq(x, s) = args
tla.selectInSet(x, s)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ object StandardLibrary {
("Apalache", "ApaFoldSet") -> ApalacheOper.foldSet,
("__apalache_folds", "__ApalacheFoldSet") -> ApalacheOper.foldSet,
("Apalache", "ApaFoldSeqLeft") -> ApalacheOper.foldSeq,
("Apalache", "Repeat") -> ApalacheOper.repeat,
// Variants
("Variants", "Variant") -> VariantOper.variant,
("Variants", "VariantFilter") -> VariantOper.variantFilter,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -810,6 +810,12 @@ class ToEtcExpr(
mkExRefApp(opsig, Seq(v, variantEx, defaultEx))

// ******************************************** Apalache **************************************************
case OperEx(ApalacheOper.repeat, op, bound, base) =>
val a = varPool.fresh
// ((a, Int) => a, Int, a) => a
val opsig = OperT1(Seq(OperT1(Seq(a, IntT1), a), IntT1, a), a)
mkExRefApp(opsig, Seq(op, bound, base))

case OperEx(ApalacheOper.mkSeq, len, ctor) =>
val a = varPool.fresh
// (Int, (Int => a)) => Seq(a)
Expand Down
Loading

0 comments on commit b54d097

Please sign in to comment.