Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce the Repeat Apalache operator #2927

Merged
merged 9 commits into from
Aug 14, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions docs/src/lang/apalache-operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,44 @@ IN
The operators `ApaFoldSet` and `ApaFoldSeqLeft` are explained in more detail in a dedicated section [here](../apalache/principles/folds.md).

----------------------------------------------------------------------------
<a name="Repeat"></a>
## Operator iteration

**Notation:** `Repeat(Op, N, x)`

**LaTeX notation:** `Repeat(Op, N, x)`

**Arguments:** Three arguments: An operator `Op`, an iteration counter `N` (a nonnegative constant integer expression), and an
initial value `x`.

**Apalache type:** `((a, Int), Int, a) => a`, for some type `a`.

**Effect:** For a given constant bound `N`, computes the value
`F(F(F(F(x,1), 2), ...), N)`. If `N=0` it evaluates to `x`.

```tla
Repeat(Op, N, x) ==
ApaFoldSeqLeft(Op, x, MkSeq(N, LAMBDA i:i))
```

Apalache implements a more efficient encoding of this operator than the default one.

**Determinism:** Deterministic.

**Errors:**
If any argument is ill-typed, or `N` is not a nonnegative constant integer expression, Apalache reports an error.

**Example in TLA+:**

```tla
Op(a) == a + 1
LET OpModified(a,i) == Op(i)
IN Repeat(OpModified, 0, 5) = 5 \* TRUE

Op2(a,i) == a + i
Repeat(Op2, 0, 5) = 15 \* TRUE
```
----------------------------------------------------------------------------
<a name="SetAsFun"></a>

## Convert a set of pairs to a function
Expand Down
13 changes: 13 additions & 0 deletions src/tla/Apalache.tla
Original file line number Diff line number Diff line change
Expand Up @@ -152,4 +152,17 @@ ApaFoldSeqLeft(__Op(_,_), __v, __seq) ==
THEN __v
ELSE ApaFoldSeqLeft(__Op, __Op(__v, Head(__seq)), Tail(__seq))

(**
* The repetition operator, used to consecutively apply an operator, starting from
* an initial value.
*
* @type: ((a, Int) => a, Int, a) => a;
*)
RECURSIVE Repeat(_,_,_)
Repeat(__F(_,_), __N, __x) ==
\* This is the TLC implementation. Apalache does it differently.
IF __N <= 0
THEN __x
ELSE __F(Repeat(__F, __N - 1, __x), __N)

===============================================================================
66 changes: 66 additions & 0 deletions test/tla/Repeat.tla
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
------------------------ MODULE Repeat ------------------------------------
EXTENDS Apalache, Integers

Inv1 ==
LET Op(a, i) == a + 1
IN Repeat(Op, 5, 0) = 5

Inv2 ==
LET Op(a, i) == a + i
IN Repeat(Op, 5, 0) = 15

\* Cyclical Op: \E k: Op^k = Id
Inv3 ==
LET Op(a,i) == (a + i) % 3
IN LET
v == 1
x0 == Repeat(Op, 0, v)
x3 == Repeat(Op, 3, v)
x6 == Repeat(Op, 6, v)
IN
/\ v = x0
/\ x0 = x3
/\ x3 = x6

\* Idempotent Op: Op^2 = Op
Inv4 ==
LET
\* @type: (Set(Set(Int)), Int) => Set(Set(Int));
Op(a, i) == {UNION a}
IN LET
v == {{1}, {2}, {3,4}}
x1 == Repeat(Op, 1, v)
x2 == Repeat(Op, 2, v)
x3 == Repeat(Op, 3, v)
IN
/\ v /= x1
/\ x1 = x2
/\ x2 = x3

\* Nilpotent Op: \E k: Op^k = 0
Inv5 ==
LET
\* @type: (Set(Int), Int) => Set(Int);
Op(a, i) == a \ { x \in a: \A y \in a: x <= y }
IN LET
v == {1,2,3}
x1 == Repeat(Op, 2, v)
x2 == Repeat(Op, 3, v)
x3 == Repeat(Op, 4, v)
IN
/\ x1 /= x2
/\ x2 = x3
/\ x3 = {}


Init == TRUE
Next == TRUE

Inv ==
/\ Inv1
/\ Inv2
/\ Inv3
/\ Inv4
/\ Inv5

===============================================================================
10 changes: 10 additions & 0 deletions test/tla/RepeatBad.tla
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
------------------------ MODULE RepeatBad ------------------------------------
Kukovec marked this conversation as resolved.
Show resolved Hide resolved
EXTENDS Apalache, Integers

Inv ==
LET Op(a, i) == a + 1
IN \E x \in {5} : Repeat(Op, x, 0) = 5

Init == TRUE
Next == TRUE
===============================================================================
18 changes: 18 additions & 0 deletions test/tla/cli-integration-tests.md
Original file line number Diff line number Diff line change
Expand Up @@ -894,6 +894,24 @@ The outcome is: NoError
EXITCODE: OK
```

### check Repeat succeeds

```sh
$ apalache-mc check --inv=Inv --length=0 Repeat.tla | sed 's/I@.*//'
...
The outcome is: NoError
...
EXITCODE: OK
```

### check RepeatBad fails

```sh
$ apalache-mc check --inv=Inv --length=0 RepeatBad.tla | sed 's/I@.*//'
...
EXITCODE: ERROR (255)
```

### check Counter.tla errors (array-encoding)

```sh
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -303,11 +303,13 @@ class SymbStateRewriterImpl(
-> List(new LabelRule(this)),
key(OperEx(ApalacheOper.gen, tla.int(2)))
-> List(new GenRule(this)),
// folds and MkSeq
// folds, repeat and MkSeq
key(OperEx(ApalacheOper.foldSet, tla.name("A"), tla.name("v"), tla.name("S")))
-> List(new FoldSetRule(this, renaming)),
key(OperEx(ApalacheOper.foldSeq, tla.name("A"), tla.name("v"), tla.name("s")))
-> List(new FoldSeqRule(this, renaming)),
key(OperEx(ApalacheOper.repeat, tla.name("Op"), tla.int(10), tla.name("s")))
-> List(new RepeatRule(this, renaming)),
key(OperEx(ApalacheOper.mkSeq, tla.int(10), tla.name("A")))
-> List(new MkSeqRule(this, renaming)),
)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
package at.forsyte.apalache.tla.bmcmt.rules

import at.forsyte.apalache.tla.bmcmt._
import at.forsyte.apalache.tla.lir._
import at.forsyte.apalache.tla.lir.oper.ApalacheOper
import at.forsyte.apalache.tla.lir.transformations.impl.IdleTracker
import at.forsyte.apalache.tla.lir.transformations.standard.IncrementalRenaming
import at.forsyte.apalache.tla.lir.values.TlaInt
import at.forsyte.apalache.tla.pp.Inliner
import at.forsyte.apalache.tla.types.tla

/**
* Rewriting rule for Repeat. This rule is similar to [[FoldSeqRule]].
Kukovec marked this conversation as resolved.
Show resolved Hide resolved
*
* @author
* Jure Kukovec
*/
class RepeatRule(rewriter: SymbStateRewriter, renaming: IncrementalRenaming) extends RewritingRule {

override def isApplicable(symbState: SymbState): Boolean = {
symbState.ex match {
case OperEx(ApalacheOper.repeat, LetInEx(NameEx(appName), TlaOperDecl(operName, params, _)), _, _) =>
appName == operName && params.size == 2
case _ => false
}
}

override def apply(state: SymbState): SymbState = state.ex match {
// assume isApplicable
case ex @ OperEx(ApalacheOper.repeat, LetInEx(NameEx(_), opDecl), boundEx, baseEx) =>
boundEx match {
case ValEx(TlaInt(n)) if n.isValidInt && n.toInt >= 0 =>
// rewrite baseEx to its final cell form
val baseState = rewriter.rewriteUntilDone(state.setRex(baseEx))

// We need the type signature. The signature of Repeat is
// \A a : ((a,Int) => a, Int, a) => a
// so the operator type must be (a,Int) => a
val a = TlaType1.fromTypeTag(baseEx.typeTag)
val opT = OperT1(Seq(a, IntT1), a)
// sanity check
TlaType1.fromTypeTag(opDecl.typeTag) match {
case `opT` => // all good
case badType =>
val msg = s"FoldSet argument ${opDecl.name} should have the tag $opT, found $badType."
throw new TypingException(msg, opDecl.ID)
}

// expressions are transient, we don't need tracking
val inliner = new Inliner(new IdleTracker, renaming)
// We can make the scope directly, since InlinePass already ensures all is well.
val seededScope: Inliner.Scope = Map(opDecl.name -> opDecl)

// To implement the Repeat rule, we generate a sequence of cells.
// At each step, we perform one application of the operator
// defined by `opDecl` to the previous partial result,
// and pass the iteration index as the second parameter
(1 to n.toInt).foldLeft(baseState) { case (partialState, i) =>
// partialState currently holds the cell representing the previous application step
val oldResultCell = partialState.asCell

// First, we inline the operator application, with cell names as parameters
val appEx = tla.appOp(
tla.name(opDecl.name, opT),
oldResultCell.toBuilder,
tla.int(i),
)

val inlinedEx = inliner.transform(seededScope)(appEx)
rewriter.rewriteUntilDone(partialState.setRex(inlinedEx))
}
case _ =>
throw new RewriterException("Apalache!Repeat expects a constant positive integer. Found: " + boundEx, ex)
}
case _ =>
throw new RewriterException("%s is not applicable".format(getClass.getSimpleName), state.ex)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ object TlaExUtil {
findRec(localBody)
baseExAndCollectionEx.foreach(findRec)

// ignore the names in the auxiliary let-in definition
case OperEx(ApalacheOper.repeat, LetInEx(_, TlaOperDecl(_, _, localBody)), boundAndBaseEx @ _*) =>
findRec(localBody)
boundAndBaseEx.foreach(findRec)

// ignore the names in the auxiliary let-in definition
case OperEx(ApalacheOper.mkSeq, len, LetInEx(_, TlaOperDecl(_, _, localBody))) =>
findRec(localBody)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class TestRewriterWithArrays
with TestSymbStateRewriterPowerset with TestSymbStateRewriterRecord with TestSymbStateRewriterSequence
with TestSymbStateRewriterSet with TestSymbStateRewriterStr with TestSymbStateRewriterTuple
with TestPropositionalOracle with TestSparseOracle with TestUninterpretedConstOracle
with TestSymbStateRewriterApalache with TestSymbStateRewriterMkSeq {
with TestSymbStateRewriterApalache with TestSymbStateRewriterMkSeq with TestSymbStateRewriterRepeat {
override protected def withFixture(test: OneArgTest): Outcome = {
solverContext = new PreproSolverContext(new Z3SolverContext(SolverConfig.default.copy(debug = true,
smtEncoding = SMTEncoding.Arrays)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class TestRewriterWithFunArrays
with TestSymbStateRewriterPowerset with TestSymbStateRewriterRecord with TestSymbStateRewriterSequence
with TestSymbStateRewriterSet with TestSymbStateRewriterStr with TestSymbStateRewriterTuple
with TestPropositionalOracle with TestSparseOracle with TestUninterpretedConstOracle
with TestSymbStateRewriterApalache with TestSymbStateRewriterMkSeq {
with TestSymbStateRewriterApalache with TestSymbStateRewriterMkSeq with TestSymbStateRewriterRepeat {
override protected def withFixture(test: OneArgTest): Outcome = {
solverContext = new PreproSolverContext(new Z3SolverContext(SolverConfig.default.copy(debug = true,
smtEncoding = SMTEncoding.FunArrays)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class TestRewriterWithOOPSLA19
with TestSymbStateRewriterVariant with TestSymbStateRewriterSequence with TestSymbStateRewriterSet
with TestSymbStateRewriterStr with TestSymbStateRewriterTuple with TestPropositionalOracle with TestSparseOracle
with TestUninterpretedConstOracle with TestSymbStateRewriterApalache with TestSymbStateRewriterMkSeq
with TestDefaultValueFactory {
with TestDefaultValueFactory with TestSymbStateRewriterRepeat {
override protected def withFixture(test: OneArgTest): Outcome = {
solverContext = new PreproSolverContext(new Z3SolverContext(SolverConfig.default.copy(debug = true,
smtEncoding = SMTEncoding.OOPSLA19)))
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package at.forsyte.apalache.tla.bmcmt

import at.forsyte.apalache.infra.passes.options.SMTEncoding
import at.forsyte.apalache.tla.lir._
import at.forsyte.apalache.tla.types.tla

trait TestSymbStateRewriterRepeat extends RewriterBase {
test("""Repeat(LET Op(a,i) == a + 1 IN Op, 5, 0) = 5""") { rewriterType: SMTEncoding =>
// Op(x, i) == x + 1
val opT = OperT1(Seq(IntT1, IntT1), IntT1)

val plusOneEx = tla.plus(tla.name("x", IntT1), tla.int(1))
val plusOneOper = tla.decl("Op", plusOneEx, tla.param("x", IntT1), tla.param("i", IntT1))
val letEx = tla.letIn(tla.name(plusOneOper.name, opT), plusOneOper)
val repeatEx = tla.repeat(letEx, 5, tla.int(0))

val rewriter = create(rewriterType)
var state = new SymbState(repeatEx, arena, Binding())
state = rewriter.rewriteUntilDone(state)
val asCell = state.asCell

// compare the value
val eqn = tla.eql(asCell.toBuilder, tla.int(5))
assertTlaExAndRestore(rewriter, state.setRex(eqn))
}

test("""Repeat(LET Op(a,i) == a + i IN Op, 5, 0) = 15""") { rewriterType: SMTEncoding =>
// Op(x, i) == x + i
val opT = OperT1(Seq(IntT1, IntT1), IntT1)

val plusiEx = tla.plus(tla.name("x", IntT1), tla.name("i", IntT1))
val plusiOper = tla.decl("Op", plusiEx, tla.param("x", IntT1), tla.param("i", IntT1))
val letEx = tla.letIn(tla.name(plusiOper.name, opT), plusiOper)
val repeatEx = tla.repeat(letEx, 5, tla.int(0))

val rewriter = create(rewriterType)
var state = new SymbState(repeatEx, arena, Binding())
state = rewriter.rewriteUntilDone(state)
val asCell = state.asCell

// compare the value
val eqn = tla.eql(asCell.toBuilder, tla.int(15))
assertTlaExAndRestore(rewriter, state.setRex(eqn))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ object BuilderCallByName {
ApalacheOper.mkSeq.name -> ApalacheOper.mkSeq,
ApalacheOper.foldSet.name -> ApalacheOper.foldSet,
ApalacheOper.foldSeq.name -> ApalacheOper.foldSeq,
ApalacheOper.repeat.name -> ApalacheOper.repeat,
ApalacheInternalOper.selectInSet.name -> ApalacheInternalOper.selectInSet,
ApalacheInternalOper.selectInFun.name -> ApalacheInternalOper.selectInFun,
ApalacheInternalOper.storeInSet.name -> ApalacheInternalOper.storeInSet,
Expand Down Expand Up @@ -375,6 +376,15 @@ object BuilderCallByName {
case ApalacheOper.foldSeq =>
val Seq(f, v, s) = args
tla.foldSeq(f, v, s)
case ApalacheOper.repeat =>
val Seq(f, n, x) = args
val nEx: TlaEx = n
nEx match {
case ValEx(TlaInt(n)) =>
tla.repeat(f, n, x)
// should never happen, for case-completeness
case _ => throw new JsonDeserializationError(s"${oper.name} requires an integer argument.")
}
case ApalacheInternalOper.selectInSet =>
val Seq(x, s) = args
tla.selectInSet(x, s)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ object StandardLibrary {
("Apalache", "ApaFoldSet") -> ApalacheOper.foldSet,
("__apalache_folds", "__ApalacheFoldSet") -> ApalacheOper.foldSet,
("Apalache", "ApaFoldSeqLeft") -> ApalacheOper.foldSeq,
("Apalache", "Repeat") -> ApalacheOper.repeat,
// Variants
("Variants", "Variant") -> VariantOper.variant,
("Variants", "VariantFilter") -> VariantOper.variantFilter,
Expand Down
6 changes: 6 additions & 0 deletions tla-typechecker/