diff --git a/UNRELEASED.md b/UNRELEASED.md index b50505ed4c..9fe11d3671 100644 --- a/UNRELEASED.md +++ b/UNRELEASED.md @@ -13,12 +13,17 @@ ### Features +* Model checker: receiving the types from with the type checker Snowcat, see #668 and #350 +* Type checker: the old Apalache type annotations are no longer supported, see #668 * Type checker: tagging all expressions with the reconstructed types, see #608 -* Type checker: experimental option `check --with-snowcat`, see #632 +* Type checker: handling TLA+ labels like `lab("a", "b") :: e`, see #653 +* Preprocessing: handling the general case of EXCEPT, see #647 ### Changed * Preprocessing: massive refactoring of the passes to support types. This may have introduced unexpected bugs. +* Model checker: translation rules for records and functions have been modified, in order to support new types. Bugs to + be expected. ### Known issues diff --git a/mod-tool/src/main/scala/at/forsyte/apalache/tla/Tool.scala b/mod-tool/src/main/scala/at/forsyte/apalache/tla/Tool.scala index d7f0d40cbe..1b1852bb0d 100644 --- a/mod-tool/src/main/scala/at/forsyte/apalache/tla/Tool.scala +++ b/mod-tool/src/main/scala/at/forsyte/apalache/tla/Tool.scala @@ -170,7 +170,7 @@ object Tool extends App with LazyLogging { executor.options.set("checker.noDeadlocks", check.noDeadlocks) executor.options.set("checker.algo", check.algo) // this option enables the new type checker in the pipeline - executor.options.set("typechecker.snowcatOn", check.withSnowcat) + executor.options.set("typechecker.snowcatOn", true) // for now, enable polymorphic types. We probably want to disable this option for the type checker executor.options.set("typechecker.inferPoly", true) diff --git a/mod-tool/src/main/scala/at/forsyte/apalache/tla/tooling/opt/CheckCmd.scala b/mod-tool/src/main/scala/at/forsyte/apalache/tla/tooling/opt/CheckCmd.scala index bfaf9805e4..5530f2c7e4 100644 --- a/mod-tool/src/main/scala/at/forsyte/apalache/tla/tooling/opt/CheckCmd.scala +++ b/mod-tool/src/main/scala/at/forsyte/apalache/tla/tooling/opt/CheckCmd.scala @@ -42,6 +42,4 @@ class CheckCmd extends Command(name = "check", description = "Check a TLA+ speci "pre-check, whether a transition is disabled, and discard it, to make SMT queries smaller, default: true") var noDeadlocks: Boolean = opt[Boolean](name = "no-deadlock", default = true, description = "do not check for deadlocks, default: true") - var withSnowcat: Boolean = - opt[Boolean](name = "with-snowcat", default = false, description = "use the new type checker Snowcat") } diff --git a/test/tla/Assignments20200309.tla b/test/tla/Assignments20200309.tla index 4e5e52795c..3aaabf7cd7 100644 --- a/test/tla/Assignments20200309.tla +++ b/test/tla/Assignments20200309.tla @@ -1,5 +1,8 @@ ----- MODULE Assignments20200309 ----- -VARIABLE a +VARIABLE + \* @type: Int; + a + \* this specification fails, as it has no expression \* that can be treated as an assignment Init == TRUE diff --git a/test/tla/Bug20190118.tla b/test/tla/Bug20190118.tla index 5265b1c94f..e88966db9e 100644 --- a/test/tla/Bug20190118.tla +++ b/test/tla/Bug20190118.tla @@ -5,8 +5,6 @@ (* BMCMT extensions *) RM == {"r1", "r2"} -a <: b == a \* a type annotation - \* new: a message type MT == [type |-> STRING, rm |-> STRING] (* END OF BMCMT extensions *) @@ -23,7 +21,7 @@ VARIABLES Message == {[type |-> t, rm |-> r]: t \in {"Prepared"}, r \in RM } \cup - {([type |-> t] <: MT) : t \in {"Commit", "Abort"} } + {[type |-> t] : t \in {"Commit", "Abort"} } Init == /\ rmState \in [RM -> {"working", "prepared"}] @@ -31,11 +29,11 @@ Init == /\ rmState["r1"] = "working" /\ rmState["r2"] = "prepared" /\ tmPrepared = {"r2"} - /\ msgs = {[type |-> "Prepared", rm |-> "r2"] <: MT} + /\ msgs = {[type |-> "Prepared", rm |-> "r2"]} TMCommit == /\ tmPrepared = RM - /\ msgs' = msgs \cup {[type |-> "Commit"] <: MT} + /\ msgs' = msgs \cup {[type |-> "Commit"]} /\ UNCHANGED <> RMPrepare(rm) == @@ -47,6 +45,6 @@ RMPrepare(rm) == Next == TMCommit \/ RMPrepare("r1") ----------------------------------------------------------------------------- \* this invariant cannot be violated in one step -Inv == ([type |-> "Commit"] <: MT) \notin msgs +Inv == [type |-> "Commit"] \notin msgs ============================================================================= diff --git a/test/tla/Bug20190921.tla b/test/tla/Bug20190921.tla index 59a6b147d7..fd64e4338f 100644 --- a/test/tla/Bug20190921.tla +++ b/test/tla/Bug20190921.tla @@ -3,9 +3,13 @@ EXTENDS Integers \* constants and variables should propagate \* in the transformations -CONSTANT N +CONSTANT + \* @type: Int; + N -VARIABLE x +VARIABLE + \* @type: Int; + x CInit == N \in 1..10 diff --git a/test/tla/Bug20200306.tla b/test/tla/Bug20200306.tla index 07bc23f525..1226801160 100644 --- a/test/tla/Bug20200306.tla +++ b/test/tla/Bug20200306.tla @@ -1,5 +1,7 @@ ----- MODULE Bug20200306 ----- -VARIABLE a +VARIABLE + \* @type: Int; + a Init == a = 1 Next == a' = a Inv == FALSE diff --git a/test/tla/Callback.tla b/test/tla/Callback.tla index e4cc8eec4b..052ca8a8b3 100644 --- a/test/tla/Callback.tla +++ b/test/tla/Callback.tla @@ -1,7 +1,9 @@ ----------------------------- MODULE Callback ---------------------------- EXTENDS Integers -VARIABLES x +VARIABLES + \* @type: Int; + x Pick(Cb(_)) == \E i \in 1..10: diff --git a/test/tla/Config.tla b/test/tla/Config.tla index b78cd9f047..b06675df69 100644 --- a/test/tla/Config.tla +++ b/test/tla/Config.tla @@ -2,7 +2,9 @@ (* a specification to check whether the configuration files are parsed *) EXTENDS Integers -VARIABLES x +VARIABLES + \* @type: Int; + x \* the default init Init == diff --git a/test/tla/ConfigParams.tla b/test/tla/ConfigParams.tla index f6291c9725..eacd19b6d6 100644 --- a/test/tla/ConfigParams.tla +++ b/test/tla/ConfigParams.tla @@ -1,13 +1,27 @@ --------------------------- MODULE ConfigParams ------------------------------- CONSTANTS + \* @type: Int; MyInt, + \* @type: Str; MyStr, + \* TODO: use a model type here + \* when #570 is closed: https://github.com/informalsystems/apalache/issues/570 + \* @type: Str; MyModelValue1, + \* @type: Str; MyModelValue2, + \* @type: Set(Int); MySet VARIABLES - x, y, z, w + \* @type: Int; + x, + \* @type: Str; + y, + \* @type: Str; + z, + \* @type: Set(Int); + w Init == /\ x = MyInt diff --git a/test/tla/ConfigReplacements.tla b/test/tla/ConfigReplacements.tla index 068d3d938c..1f4d1c52b6 100644 --- a/test/tla/ConfigReplacements.tla +++ b/test/tla/ConfigReplacements.tla @@ -1,5 +1,7 @@ ----------------------- MODULE ConfigReplacements ----------------------------- -VARIABLES x +VARIABLES + \* @type: Int; + x Value == 0 diff --git a/test/tla/ConfigUnsorted.tla b/test/tla/ConfigUnsorted.tla index a6196d4238..616a0a13eb 100644 --- a/test/tla/ConfigUnsorted.tla +++ b/test/tla/ConfigUnsorted.tla @@ -1,6 +1,8 @@ -------------------------- MODULE ConfigUnsorted ---------------------------------------- (* A specification that introduces preprocessing issues *) -VARIABLES x +VARIABLES + \* @type: Int; + x A == 1 B == 2 diff --git a/test/tla/Counter.tla b/test/tla/Counter.tla index f443637ee1..2b86ac0ce2 100644 --- a/test/tla/Counter.tla +++ b/test/tla/Counter.tla @@ -1,7 +1,9 @@ ------ MODULE Counter ------ EXTENDS Naturals -VARIABLES x +VARIABLES + \* @type: Int; + x Init == x \in 1..1000 diff --git a/test/tla/EWD840.tla b/test/tla/EWD840.tla index 5df2f59560..94fff731be 100644 --- a/test/tla/EWD840.tla +++ b/test/tla/EWD840.tla @@ -10,7 +10,15 @@ EXTENDS Naturals N == 4 (*ASSUME NAssumption == N \in Nat \ {0}*) -VARIABLES active, color, tpos, tcolor +VARIABLES + \* @type: Int -> Bool; + active, + \* @type: Int -> Str; + color, + \* @type: Int; + tpos, + \* @type: Str; + tcolor Nodes == 0 .. N-1 Color == {"white", "black"} diff --git a/test/tla/ExistsAsValue.tla b/test/tla/ExistsAsValue.tla index fed281fddb..8bb9795bb6 100644 --- a/test/tla/ExistsAsValue.tla +++ b/test/tla/ExistsAsValue.tla @@ -1,6 +1,8 @@ ------------------------ MODULE ExistsAsValue ------------------------- \* a test for the issue #148 -VARIABLES x +VARIABLES + \* @type: Bool; + x Init == x = TRUE diff --git a/test/tla/Fix365_ExistsSubset3.tla b/test/tla/Fix365_ExistsSubset3.tla index 962fa5cf95..4376c8d6d3 100644 --- a/test/tla/Fix365_ExistsSubset3.tla +++ b/test/tla/Fix365_ExistsSubset3.tla @@ -2,9 +2,6 @@ (* A tricky bug that happened in evidence handling *) EXTENDS Integers -\* old apalache annotations -a <: b == a - Proc == {"p1", "p2"} Rounds == { 0, 1, 2 } @@ -27,7 +24,7 @@ Next == LET \* @type: Set(Str); Y == { m.src: m \in msgs } IN \* the third ingredient of the bug - /\ msgs /= {} <: {MT} + /\ msgs /= {} /\ UNCHANGED msgs =============================================================================== diff --git a/test/tla/HandshakeWithTypes.tla b/test/tla/HandshakeWithTypes.tla index 1e1f3a8404..c06112f6f7 100644 --- a/test/tla/HandshakeWithTypes.tla +++ b/test/tla/HandshakeWithTypes.tla @@ -7,18 +7,22 @@ *) EXTENDS Integers -VARIABLES msgs, \* the set of all messages - iseqno, \* Initiator's sequence number - rseqno, \* Receiver's sequence number - istate, \* Initiator's state - rstate \* Receiver's state - -a <: b == a +VARIABLES + \* @type: Set([syn: Bool, ack: Bool, seqno: Int, ackno: Int]); + msgs, \* the set of all messages + \* @type: Int; + iseqno, \* Initiator's sequence number + \* @type: Int; + rseqno, \* Receiver's sequence number + \* @type: Str; + istate, \* Initiator's state + \* @type: Str; + rstate \* Receiver's state MT == [syn |-> BOOLEAN, ack |-> BOOLEAN, seqno |-> Int, ackno |-> Int] Init == - /\ msgs = {} <: {MT} + /\ msgs = {} /\ iseqno = 0 /\ rseqno = 0 /\ istate = "INIT" @@ -28,7 +32,7 @@ SendSyn == /\ istate = "INIT" /\ \E no \in Nat: /\ msgs' = msgs \union {[syn |-> TRUE, - ack |-> FALSE, seqno |-> no] <: MT} + ack |-> FALSE, seqno |-> no]} /\ iseqno' = no + 1 /\ istate' = "SYN-SENT" /\ UNCHANGED <> @@ -36,7 +40,7 @@ SendSyn == SendSynAck == /\ rstate = "LISTEN" /\ \E seqno, ackno \in Nat: - /\ ([syn |-> TRUE, ack |-> FALSE, seqno |-> seqno] <: MT) \in msgs + /\ [syn |-> TRUE, ack |-> FALSE, seqno |-> seqno] \in msgs /\ msgs' = msgs \union {[syn |-> TRUE, ack |-> TRUE, seqno |-> seqno + 1, ackno |-> ackno]} diff --git a/test/tla/HourClock.tla b/test/tla/HourClock.tla index 8d43d30baa..b7c700bb27 100644 --- a/test/tla/HourClock.tla +++ b/test/tla/HourClock.tla @@ -2,7 +2,10 @@ \* This is a local copy of the example from Specifying Systems: \* https://github.com/tlaplus/Examples/blob/master/specifications/SpecifyingSystems/RealTime/HourClock.tla EXTENDS Naturals -VARIABLE hr +VARIABLE + \* @type: Int; + hr + HCini == hr \in (1 .. 12) HCnxt == hr' = IF hr # 12 THEN hr + 1 ELSE 1 HC == HCini /\ [][HCnxt]_hr diff --git a/tla-assignments/src/test/resources/assignmentSolver/ITE_CASE.tla b/test/tla/ITE_CASE.tla similarity index 70% rename from tla-assignments/src/test/resources/assignmentSolver/ITE_CASE.tla rename to test/tla/ITE_CASE.tla index 87d945ae10..5be420bfb1 100644 --- a/tla-assignments/src/test/resources/assignmentSolver/ITE_CASE.tla +++ b/test/tla/ITE_CASE.tla @@ -1,5 +1,9 @@ ------------------------------ MODULE ITE_CASE ------------------------------ -VARIABLES x, y +VARIABLES + \* @type: Int; + x, + \* @type: Int; + y S == {1,2,3} @@ -11,6 +15,9 @@ ITE(p, et, ee) == /\ IF p THEN et ELSE ee Next == ITE( x = y', x' = 2, x' \in S ) -Spec == /\ Init /\ [][Next]_<> +\* @type: <>; +vars == <> + +Spec == /\ Init /\ [][Next]_vars ============================================================================= diff --git a/test/tla/Inline.tla b/test/tla/Inline.tla index 8b2dcc89ad..7395a609a8 100644 --- a/test/tla/Inline.tla +++ b/test/tla/Inline.tla @@ -1,5 +1,7 @@ ---------------------------- MODULE Inline ------------------------------- -VARIABLE x +VARIABLE + \* @type: Int; + x A == 3 diff --git a/test/tla/NatCounter.tla b/test/tla/NatCounter.tla index 7bd91c1f3d..d1242e46c4 100644 --- a/test/tla/NatCounter.tla +++ b/test/tla/NatCounter.tla @@ -1,7 +1,9 @@ ----------------------------- MODULE NatCounter ------------------------ EXTENDS Naturals -VARIABLE x +VARIABLE + \* @type: Int; + x Init == x = 3 diff --git a/test/tla/NeedForTypesWithTypes.tla b/test/tla/NeedForTypesWithTypes.tla index 27abffcd2a..0c21f690ce 100644 --- a/test/tla/NeedForTypesWithTypes.tla +++ b/test/tla/NeedForTypesWithTypes.tla @@ -4,20 +4,24 @@ *) EXTENDS Integers, Sequences, FiniteSets -CONSTANTS InSet \* an input set -VARIABLES Left, \* a storage for the yet untransformed elements - OutSeq \* the output sequence +CONSTANTS + \* @type: Set(Int); + InSet \* an input set -a <: b == a +VARIABLES + \* @type: Set(Int); + Left, \* a storage for the yet untransformed elements + \* @type: Seq(Int); + OutSeq \* the output sequence ConstInit == InSet = 1..4 Init == - /\ OutSeq = << >> <: Seq(Int) + /\ OutSeq = << >> /\ Left = InSet Next == - IF Left = {} <: {Int} + IF Left = {} THEN UNCHANGED <> ELSE \E x \in Left: /\ OutSeq' = Append(OutSeq, x) diff --git a/test/tla/NonNullaryLet.tla b/test/tla/NonNullaryLet.tla index 6a285f210b..871f446bbd 100644 --- a/test/tla/NonNullaryLet.tla +++ b/test/tla/NonNullaryLet.tla @@ -1,10 +1,12 @@ --------------- MODULE NonNullaryLet ---------------- EXTENDS Integers -VARIABLE n +VARIABLE + \* @type: Int; + n Foo == LET r(x) == TRUE IN r(1) Init == n = 0 Next == Foo /\ n' = 1 -=========================================== \ No newline at end of file +=========================================== diff --git a/test/tla/Paxos.tla b/test/tla/Paxos.tla index 686cbc2320..95b2def603 100644 --- a/test/tla/Paxos.tla +++ b/test/tla/Paxos.tla @@ -25,11 +25,6 @@ Value == {0, 1} Acceptor == {"a1", "a2"} Quorum == {{"a1", "a2"}} -a <: b == a -MT == [type |-> STRING, bal |-> Int, - mbal |-> Int, acc |-> STRING, - val |-> Int, mval |-> Int] - (* Acceptor == {"a1", "a2", "a3"} Quorum == {{"a1", "a2"}, {"a2", "a3"}, {"a1", "a3"}} @@ -81,11 +76,16 @@ Message == [type : {"1a"}, bal : Ballot] \cup [type : {"2a"}, bal : Ballot, val : Value] \cup [type : {"2b"}, acc : Acceptor, bal : Ballot, val : Value] ----------------------------------------------------------------------------- -VARIABLE maxBal, - maxVBal, \* <> is the vote with the largest - maxVal, \* ballot number cast by a; it equals <<-1, None>> if +VARIABLE + \* @type: Str -> Int; + maxBal, + \* @type: Str -> Int; + maxVBal, \* <> is the vote with the largest + \* @type: Str -> Int; + maxVal, \* ballot number cast by a; it equals <<-1, None>> if \* a has not cast any vote. - msgs \* The set of all messages that have been sent. + \* @type: Set([type: Str, bal: Int, acc: Str, mbal: Int, mval: Int, val: Int]); + msgs \* The set of all messages that have been sent. (***************************************************************************) (* NOTE: *) @@ -124,8 +124,13 @@ TypeOK == /\ maxBal \in [Acceptor -> Ballot \cup {-1}] (* the array `votes' describing the votes cast by the acceptors is defined *) (* as follows. *) (***************************************************************************) + +\* the original specification does not have this definition, we need it for types +\* @type: (Int, Int) => <>; +pair(i, j) == <> + votes == [a \in Acceptor |-> - {<> : m \in {mm \in msgs: /\ mm.type = "2b" + {pair(m.bal, m.val) : m \in {mm \in msgs: /\ mm.type = "2b" /\ mm.acc = a (* BMCMT check, let's add smth like HASFIELD(mm, "val") *) /\ mm.val = mm.val @@ -135,13 +140,13 @@ votes == [a \in Acceptor |-> Init == /\ maxBal = [a \in Acceptor |-> -1] /\ maxVBal = [a \in Acceptor |-> -1] /\ maxVal = [a \in Acceptor |-> None] - /\ msgs = {} <: {MT} + /\ msgs = {} (***************************************************************************) (* The actions. We begin with the subaction (an action that will be used *) (* to define the actions that make up the next-state action. *) (***************************************************************************) -Send(m) == /\ msgs' = msgs \cup {m <: MT} +Send(m) == /\ msgs' = msgs \cup {m} (***************************************************************************) @@ -151,7 +156,7 @@ Send(m) == /\ msgs' = msgs \cup {m <: MT} (* m with m.type = "1a") that begins ballot b. *) (***************************************************************************) Phase1a(b) == /\ Send([type |-> "1a", bal |-> b]) - /\ UNCHANGED <> + /\ UNCHANGED <> (***************************************************************************) (* Upon receipt of a ballot b phase 1a message, acceptor a can perform a *) @@ -165,7 +170,7 @@ Phase1b(a) == /\ \E m \in msgs : /\ maxBal' = [maxBal EXCEPT ![a] = m.bal] /\ Send([type |-> "1b", acc |-> a, bal |-> m.bal, mbal |-> maxVBal[a], mval |-> maxVal[a]]) - /\ UNCHANGED <> + /\ UNCHANGED <> (***************************************************************************) (* The Phase2a(b, v) action can be performed by the ballot b leader if two *) @@ -195,12 +200,12 @@ Phase2a(b, v) == /\ m.bal = b} Q1bv == {m \in Q1b : m.mbal \geq 0} IN /\ \A a \in Q : \E m \in Q1b : m.acc = a - /\ \/ Q1bv = ({} <: {MT}) + /\ \/ Q1bv = {} \/ \E m \in Q1bv : /\ m.mval = v /\ \A mm \in Q1bv : m.mbal \geq mm.mbal /\ Send([type |-> "2a", bal |-> b, val |-> v]) - /\ UNCHANGED <> + /\ UNCHANGED <> (***************************************************************************) (* The Phase2b(a) action is performed by acceptor a upon receipt of a *) @@ -282,9 +287,10 @@ Inv == (*/\ TypeOK*) /\ (m.mbal \geq 0) => <> \in votes[m.acc] /\ (m.type = "2a") => /\ \E Q \in Quorum : - (*V!*)ShowsSafeAt(Q, m.bal, m.val) + ShowsSafeAt(Q, m.bal, m.val) /\ \A mm \in msgs : /\ mm.type = "2a" /\ mm.bal = m.bal => mm.val = m.val + (*/\ V!Inv*) ============================================================================ diff --git a/test/tla/Rec1.tla b/test/tla/Rec1.tla index 9adb7eadda..daeda40fb3 100644 --- a/test/tla/Rec1.tla +++ b/test/tla/Rec1.tla @@ -6,7 +6,9 @@ *) EXTENDS Integers -VARIABLES f +VARIABLES + \* @type: Int; + f RECURSIVE Fact(_) diff --git a/test/tla/Rec10.tla b/test/tla/Rec10.tla index 3b220c7cf3..c4ff862be6 100644 --- a/test/tla/Rec10.tla +++ b/test/tla/Rec10.tla @@ -6,7 +6,9 @@ *) EXTENDS Integers -VARIABLES f +VARIABLES + \* @type: Int; + f RECURSIVE Fact(_) diff --git a/test/tla/Rec11.tla b/test/tla/Rec11.tla index 98695f0124..cd90381af8 100644 --- a/test/tla/Rec11.tla +++ b/test/tla/Rec11.tla @@ -6,7 +6,9 @@ *) EXTENDS Integers -VARIABLES f +VARIABLES + \* @type: Int; + f RECURSIVE Fact(_) diff --git a/test/tla/Rec12.tla b/test/tla/Rec12.tla index 7aa036423b..a6f4ef9762 100644 --- a/test/tla/Rec12.tla +++ b/test/tla/Rec12.tla @@ -10,7 +10,10 @@ RECURSIVE A(_) A(x) == IF x < 1 THEN x ELSE 1 + A(x - 1) ====================================================================== -VARIABLES f +VARIABLES + \* @type: Int; + f + I == INSTANCE inner UNROLL_DEFAULT_I_A == 0 diff --git a/test/tla/Rec13.tla b/test/tla/Rec13.tla index 933b448a55..76cfc76bd0 100644 --- a/test/tla/Rec13.tla +++ b/test/tla/Rec13.tla @@ -1,7 +1,9 @@ ----------------- MODULE Rec13 ---------------------- EXTENDS Integers -VARIABLE y +VARIABLE + \* @type: Int; + y LOCAL Send[x \in { 1, 2}] == x \* at this point, we expect x = 2, not x = 1 @@ -13,4 +15,4 @@ Init == y = SendAll(1) Next == UNCHANGED y Inv == y = 2 -===================================================== \ No newline at end of file +===================================================== diff --git a/test/tla/Rec2.tla b/test/tla/Rec2.tla index 092d80952b..922ee82eb2 100644 --- a/test/tla/Rec2.tla +++ b/test/tla/Rec2.tla @@ -6,16 +6,17 @@ *) EXTENDS Integers -VARIABLES size, set - -a <: b == a -IntSet(S) == S <: {Int} +VARIABLES + \* @type: Int; + size, + \* @type: Set(Int); + set RECURSIVE Card(_) \* this is very close to how cardinality is computed in Apalache Card(S) == - IF S = IntSet({}) + IF S = {} THEN 0 ELSE \* CHOOSE is introduced with a LET definition, to fix its value, whatever it may be @@ -31,7 +32,7 @@ UNROLL_TIMES_Card == 10 UNROLL_DEFAULT_Card == 0 Init == - /\ set = IntSet({}) + /\ set = {} /\ size = 0 Next == diff --git a/test/tla/Rec3.tla b/test/tla/Rec3.tla index 43e4d28456..5e153c9695 100644 --- a/test/tla/Rec3.tla +++ b/test/tla/Rec3.tla @@ -8,10 +8,15 @@ *) EXTENDS Integers -VARIABLES n, fibComp, fibCompPrev, fibSpec - -\* the syntax for type annotations -a <: b == a +VARIABLES + \* @type: Int; + n, + \* @type: Int; + fibComp, + \* @type: Int; + fibCompPrev, + \* @type: Int; + fibSpec \* the type of the function Fib FibT == [Int -> Int] @@ -22,7 +27,7 @@ Fib[k \in 0..15] == THEN 0 ELSE IF k <= 2 THEN 1 - ELSE (Fib <: FibT)[k - 2] + (Fib <: FibT)[k - 1] + ELSE Fib[k - 2] + Fib[k - 1] Init == /\ n = 0 diff --git a/test/tla/Rec4.tla b/test/tla/Rec4.tla index 7d7cb5da79..dd587e4907 100644 --- a/test/tla/Rec4.tla +++ b/test/tla/Rec4.tla @@ -6,7 +6,9 @@ *) EXTENDS Integers -VARIABLES f +VARIABLES + \* @type: Int; + f RECURSIVE Fib(_) diff --git a/test/tla/Rec5.tla b/test/tla/Rec5.tla index 4565d59504..5023028f49 100644 --- a/test/tla/Rec5.tla +++ b/test/tla/Rec5.tla @@ -9,16 +9,14 @@ EXTENDS Integers, FiniteSets MAX_POWER == 3 \* the maximal voting power Procs == {"a", "b", "c"} \* the set of processes -VARIABLES votingPower - -a <: b == a - -StrSet(S) == S <: {STRING} +VARIABLES + \* @type: Str -> Int; + votingPower RECURSIVE Sum(_) Sum(S) == - IF S = StrSet({}) + IF S = {} THEN 0 ELSE LET x == CHOOSE y \in S: TRUE IN votingPower[x] + Sum(S \ {x}) diff --git a/test/tla/Rec6.tla b/test/tla/Rec6.tla index 7dd2b5fbdb..d429ffd078 100644 --- a/test/tla/Rec6.tla +++ b/test/tla/Rec6.tla @@ -3,14 +3,16 @@ EXTENDS Integers N == 5 -VARIABLES set, count - -a <: b == a +VARIABLES + \* @type: Set(Int); + set, + \* @type: Int; + count RECURSIVE Sum(_) Sum(S) == - IF S = {} <: {Int} + IF S = {} THEN 0 ELSE LET x == CHOOSE y \in S: TRUE IN x + Sum(S \ {x}) @@ -19,7 +21,7 @@ UNROLL_DEFAULT_Sum == 0 UNROLL_TIMES_Sum == N Init == - /\ set = {} <: {Int} + /\ set = {} /\ count = 0 Next == diff --git a/test/tla/Rec8.tla b/test/tla/Rec8.tla index 3178f4b185..2e009d87f1 100644 --- a/test/tla/Rec8.tla +++ b/test/tla/Rec8.tla @@ -1,13 +1,13 @@ ------------------------------ MODULE Rec8 ------------------------------------ EXTENDS Integers -VARIABLES n, factSpec, factComp - -\* the syntax for type annotations -a <: b == a - -\* the type of the factorial function -FactT == [Int -> Int] +VARIABLES + \* @type: Int; + n, + \* @type: Int; + factSpec, + \* @type: Int; + factComp (* Defining a recursive function on a finite domain. Although it is rather @@ -19,7 +19,7 @@ FactT == [Int -> Int] Fact[k \in 1..20] == IF k <= 1 THEN 1 - ELSE k * (Fact <: FactT)[k - 1] + ELSE k * Fact[k - 1] Init == /\ n = 1 diff --git a/test/tla/Rec9.tla b/test/tla/Rec9.tla index a8e1e0c346..4df04d92aa 100644 --- a/test/tla/Rec9.tla +++ b/test/tla/Rec9.tla @@ -15,14 +15,6 @@ VARIABLES \* @type: Int; size -\* the syntax for type annotations -a <: b == a - -IntSet(S) == S <: {Int} - -\* the type of the function Card -CardT == [{Int} -> Int] - (* The set cardinality function. It needs an upper bound on the set size. Although this function looks nice, be warned that this definition requires us @@ -30,13 +22,13 @@ CardT == [{Int} -> Int] for the function Card. This encoding is (at least) double-exponential. *) Card[S \in SUBSET NUMS] == - IF S = IntSet({}) + IF S = {} THEN 0 ELSE LET i == CHOOSE j \in S: TRUE IN - 1 + (Card <: CardT)[S \ {i}] + 1 + Card[S \ {i}] Init == - /\ set = IntSet({}) + /\ set = {} /\ size = 0 Next == diff --git a/test/tla/Selections.tla b/test/tla/Selections.tla new file mode 100644 index 0000000000..cab28a8651 --- /dev/null +++ b/test/tla/Selections.tla @@ -0,0 +1,26 @@ +---------------------- MODULE Selections ---------------------- +VARIABLE + \* @type: Int; + x, + \* @type: Int; + y, + \* @type: Int; + z + +Init == + x = 0 /\ y = 1 /\ z = 2 + +Next == /\ /\ \/ y' \in {1} + \/ y' \in {2} + /\ \/ x' \in {3} + \/ x' \in {4} + \/ /\ z' \in {5} + /\ y' \in {6} + /\ \/ z' \in {7} + \/ z' \in {8} + /\ \/ x' \in {9} + \/ /\ x' \in {10} + /\ x' \in {11} + +============================================================== + diff --git a/tla-assignments/src/test/resources/assignmentSolver/SimpT1.tla b/test/tla/SimpT1.tla similarity index 79% rename from tla-assignments/src/test/resources/assignmentSolver/SimpT1.tla rename to test/tla/SimpT1.tla index ac55a040f1..328a770fd0 100644 --- a/tla-assignments/src/test/resources/assignmentSolver/SimpT1.tla +++ b/test/tla/SimpT1.tla @@ -1,5 +1,8 @@ --------------------------- MODULE SimpT1 --------------------------- (* + The below specification is written for testing purposes. It does not reflect + any real algorithm. + A TLA+ specification of from the synchronous model with symmetric Byzantine faults (Algorithm 1). @@ -19,10 +22,24 @@ InvalidValues == 2..2 \* e.g., sent by a Byzantine process nil == -1 \* these variables are exactly as in the pseudo-code -VARIABLES h, e, decision, proposal, vote +VARIABLES + \* @type: Int -> Int; + h, + \* @type: Int -> Int; + e, + \* @type: Int -> (Int -> Int); + decision, + \* @type: Int -> Int; + proposal, + \* @type: Int -> Int; + vote \* book-keeping variables -VARIABLES round, faultyMessages +VARIABLES + \* @type: Str; + round, + \* @type: Set([type: Str, src: Int, height: Int, epoch: Int, proposal: Int, vote: Int]); + faultyMessages Init == /\ h = [p \in Procs |-> 0] /\ e = [p \in Procs |-> 0] @@ -48,7 +65,7 @@ SentCorrect(p) == ELSE {} [] round = "PROPOSE" -> { [type |-> "PROPOSE", height |-> h[p], epoch |-> e[p], proposal |-> proposal[p]] } - [] round = "VOTE" -> + [] OTHER (*round = "VOTE"*) -> { [type |-> "VOTE", height |-> h[p], epoch |-> e[p], vote |-> vote[p]] } PreProposeFaulty == @@ -64,12 +81,12 @@ VoteFaulty == epoch: Epochs, vote: ValidValues \cup InvalidValues] IsSentByFaulty == - faultyMessages' \in - CASE round = "PRE-PROPOSE" -> PreProposeFaulty - [] round = "PROPOSE" -> ProposeFaulty - [] round = "VOTE" -> VoteFaulty + CASE round = "PRE-PROPOSE" -> faultyMessages' \in PreProposeFaulty + [] round = "PROPOSE" -> faultyMessages' \in ProposeFaulty + [] OTHER (*round = "VOTE"*) -> faultyMessages' \inVoteFaulty +\* @type: (Int, Set([type: Str, src: Int, height: Int, epoch: Int, proposal: Int, vote: Int])) => Set([type: Str, src: Int, height: Int, epoch: Int, proposal: Int, vote: Int]); Deliver(p, sent) == CASE round = "PRE-PROPOSE" -> { m \in sent : m.type = "PRE-PROPOSE" /\ m.src = Proposer(p) /\ m.height = h[p] /\ m.epoch = e[p] } @@ -77,11 +94,16 @@ Deliver(p, sent) == [] round = "PROPOSE" -> { m \in sent : m.type = "PROPOSE" /\ m.height = h[p] /\ m.epoch = e[p] } - [] round = "VOTE" -> + [] OTHER (*round = "VOTE"*) -> { m \in sent : m.type = "VOTE" /\ m.height = h[p] /\ m.epoch = e[p] } +\* @type: (Int, Int -> Set([type: Str, src: Int, height: Int, epoch: Int, proposal: Int])) => Int; FindProposal(p, delivered) == - LET vs == { m2.proposal: m2 \in { m \in delivered[p] : m.type = "PRE-PROPOSE" /\ m.height = h[p] /\ m.epoch = h[p] /\ IsValid(m.proposal) }} IN + LET + \* @type: Set(Int); + vs == { m2.proposal: m2 \in { m \in delivered[p] : + m.type = "PRE-PROPOSE" /\ m.height = h[p] /\ m.epoch = h[p] /\ IsValid(m.proposal) }} + IN IF vs = {} THEN nil ELSE CHOOSE v \in vs: TRUE \* actually, vs is a singleton set, but we use choose to pick an element @@ -91,6 +113,7 @@ PrePropose(delivered) == /\ proposal' = [p \in Procs |-> Id(FindProposal(p, delivered))] /\ UNCHANGED <> +\* @type: (Int, Int -> Set([type: Str, src: Int, height: Int, epoch: Int, proposal: Int, vote: Int])) => Int; ChooseValue(p, delivered) == LET vs == { m2.proposal: m2 \in { m \in delivered[p] : m.type = "PROPOSE" /\ m.height = h[p] /\ m.epoch = h[p] }} IN IF vs = {} @@ -104,6 +127,7 @@ Propose(delivered) == /\ vote' = [p \in Procs |-> Id(ChooseValue(p, delivered))] /\ UNCHANGED <> +\* @type: (Int, Int -> Set([type: Str, src: Int, height: Int, epoch: Int, proposal: Int, vote: Int])) => Int; ChooseDecision(p, delivered) == LET vs == { m2.vote: m2 \in { m \in delivered[p] : m.type = "VOTE" /\ m.height = h[p] /\ m.epoch = h[p] }} IN IF vs = {} @@ -111,6 +135,7 @@ ChooseDecision(p, delivered) == ELSE LET decided == CHOOSE v \in vs: TRUE IN IF IsValid(decided) THEN decided ELSE nil +\* @type: (Int, Int -> Int) => Bool; IsDecidedNow(p, d) == d[p] /= nil /\ decision[p][h[p]] = nil Overflow(p) == h[p] \notin Heights \/ e[p] \notin Epochs @@ -134,7 +159,7 @@ Vote(delivered) == IF Overflow(p) \/ IsDecidedNow(p, d) THEN proposal'[p] = proposal[p] ELSE proposal'[p] \in ValidValues - /\ UNCHANGED <> + /\ UNCHANGED vote Compute(delivered) == diff --git a/test/tla/Slicer1.tla b/test/tla/Slicer1.tla index 66d49433b5..686339445c 100644 --- a/test/tla/Slicer1.tla +++ b/test/tla/Slicer1.tla @@ -1,7 +1,9 @@ --------------------------- MODULE Slicer1 ------------------------------------ (* Testing slicing of symbolic transitions *) -VARIABLE state +VARIABLE + \* @type: Str; + state Init == state = "Init" diff --git a/test/tla/Slicer2.tla b/test/tla/Slicer2.tla index 1e8d2ee68d..c828173222 100644 --- a/test/tla/Slicer2.tla +++ b/test/tla/Slicer2.tla @@ -1,7 +1,9 @@ --------------------------- MODULE Slicer2 ------------------------------------ (* Testing slicing of symbolic transitions *) -VARIABLE state +VARIABLE + \* @type: Str; + state Init == state = "Init" diff --git a/test/tla/Slicer3.tla b/test/tla/Slicer3.tla index 3283b6d2e5..600fdc6631 100644 --- a/test/tla/Slicer3.tla +++ b/test/tla/Slicer3.tla @@ -1,7 +1,9 @@ --------------------------- MODULE Slicer3 ------------------------------------ (* Testing slicing of symbolic transitions *) -VARIABLE state +VARIABLE + \* @type: Str; + state Init == state = "Init" diff --git a/test/tla/Slicer4.tla b/test/tla/Slicer4.tla index 2443608852..44004a1e87 100644 --- a/test/tla/Slicer4.tla +++ b/test/tla/Slicer4.tla @@ -1,7 +1,9 @@ --------------------------- MODULE Slicer4 ------------------------------------ (* Testing slicing of symbolic transitions *) -VARIABLE state +VARIABLE + \* @type: Str; + state Init == state = "Init" diff --git a/test/tla/Slicer5.tla b/test/tla/Slicer5.tla index 0dad697868..741953a742 100644 --- a/test/tla/Slicer5.tla +++ b/test/tla/Slicer5.tla @@ -2,7 +2,13 @@ (* Testing slicing of symbolic transitions *) EXTENDS Integers, FiniteSets -VARIABLE state, n, k +VARIABLE + \* @type: Str; + state, + \* @type: Int; + n, + \* @type: Int; + k Init == /\ state = "Init" diff --git a/test/tla/ast.tla b/test/tla/ast.tla index 4faec9d586..23201049e4 100644 --- a/test/tla/ast.tla +++ b/test/tla/ast.tla @@ -10,7 +10,6 @@ Proc == 1..N NoPrnt == 10 root == 1 nbrs == { pair(1, 2), pair(2, 3), pair(2, 4), pair(3, 4), pair(4, 5), pair(5, 1) } -a <: b == b \*ASSUME NoPrnt \notin Proc /\ nbrs \subseteq Proc \times Proc VARIABLES @@ -25,7 +24,7 @@ vars == <> Init == /\ prnt = [i \in Proc |-> NoPrnt] /\ rpt = [i \in Proc |-> FALSE] - /\ msg = {} <: {<>} + /\ msg = {} CanSend(i, k) == (<> \in nbrs) /\ (i = root \/ prnt[i] # NoPrnt) @@ -42,14 +41,14 @@ Parent(i) == /\ prnt[i] # NoPrnt /\ ~rpt[i] Next == \E i, j, k \in Proc: - IF i # root /\ prnt[i] = NoPrnt /\ <> \in msg + IF i # root /\ prnt[i] = NoPrnt /\ pair(j, i) \in msg THEN Update(i, j) ELSE \/ Send(i, k) \/ Parent(i) \/ UNCHANGED <> Spec == /\ Init /\ [][Next]_vars /\ WF_vars(\E i, j, k \in Proc: - IF i # root /\ prnt[i] = NoPrnt /\ <> \in msg + IF i # root /\ prnt[i] = NoPrnt /\ pair(j, i) \in msg THEN Update(i, j) ELSE \/ Send(i, k) \/ Parent(i) \/ UNCHANGED <>) @@ -62,5 +61,5 @@ Termination == <>(\A i \in Proc : i = root \/ (prnt[i] # NoPrnt /\ < OneParent == [][\A i \in Proc : prnt[i] # NoPrnt => prnt[i] = prnt'[i]]_vars -SntMsg == \A i \in Proc: (i # root /\ prnt[i] = NoPrnt => \A j \in Proc: <> \notin msg) +SntMsg == \A i \in Proc: (i # root /\ prnt[i] = NoPrnt => \A j \in Proc: pair(i ,j) \notin msg) ============================================================================= diff --git a/test/tla/cli-integration-tests.md b/test/tla/cli-integration-tests.md index f0bfcefeef..d7995b3c18 100644 --- a/test/tla/cli-integration-tests.md +++ b/test/tla/cli-integration-tests.md @@ -148,7 +148,7 @@ EXITCODE: OK #### check factorization find a counterexample ```sh -$ apalache-mc check --length=2 --inv=Inv --with-snowcat factorization.tla | sed 's/I@.*//' +$ apalache-mc check --length=2 --inv=Inv factorization.tla | sed 's/I@.*//' ... The outcome is: Error Checker has found an error @@ -158,7 +158,7 @@ Checker has found an error ### check Fix531.tla reports no error: regression for issue 531 ```sh -$ apalache-mc check --length=1 --with-snowcat Fix531.tla | sed 's/I@.*//' +$ apalache-mc check --length=1 Fix531.tla | sed 's/I@.*//' ... The outcome is: NoError ... @@ -167,7 +167,7 @@ The outcome is: NoError ### check UnchangedExpr471.tla reports no error: regression for issue 471 ```sh -$ apalache-mc check --cinit=ConstInit --length=1 --with-snowcat UnchangedExpr471.tla | sed 's/I@.*//' +$ apalache-mc check --cinit=ConstInit --length=1 UnchangedExpr471.tla | sed 's/I@.*//' ... The outcome is: NoError ... @@ -185,7 +185,7 @@ The outcome is: NoError ### check InvSub for SafeMath reports no error: regression for issue 450 ```sh -$ apalache-mc check --length=1 --inv=InvSub --with-snowcat SafeMath.tla | sed 's/I@.*//' +$ apalache-mc check --length=1 --inv=InvSub SafeMath.tla | sed 's/I@.*//' ... The outcome is: NoError ... @@ -194,7 +194,7 @@ The outcome is: NoError ### check InvAdd for SafeMath reports no error: regression for issue 450 ```sh -$ apalache-mc check --length=1 --inv=InvAdd --with-snowcat SafeMath.tla | sed 's/I@.*//' +$ apalache-mc check --length=1 --inv=InvAdd SafeMath.tla | sed 's/I@.*//' ... The outcome is: NoError ... @@ -203,7 +203,7 @@ The outcome is: NoError ### check Fix365_ExistsSubset succeeds: regression for issue 365 ```sh -$ apalache-mc check --length=10 --with-snowcat Fix365_ExistsSubset.tla | sed 's/I@.*//' +$ apalache-mc check --length=10 Fix365_ExistsSubset.tla | sed 's/I@.*//' ... The outcome is: NoError ... @@ -212,7 +212,7 @@ The outcome is: NoError ### check Fix365_ExistsSubset2 succeeds: regression for issue 365 ```sh -$ apalache-mc check --length=10 --with-snowcat Fix365_ExistsSubset2.tla | sed 's/I@.*//' +$ apalache-mc check --length=10 Fix365_ExistsSubset2.tla | sed 's/I@.*//' ... The outcome is: NoError ... @@ -221,7 +221,7 @@ The outcome is: NoError ### check Fix365_ExistsSubset3 succeeds: regression for issue 365 ```sh -$ apalache-mc check --length=10 --with-snowcat Fix365_ExistsSubset3.tla | sed 's/I@.*//' +$ apalache-mc check --length=10 Fix365_ExistsSubset3.tla | sed 's/I@.*//' ... The outcome is: NoError ... @@ -230,7 +230,7 @@ The outcome is: NoError ### check Bug20201118 succeeds: regression for issue 333 ```sh -$ apalache-mc check --length=10 --init=Init --next=Next --inv=Inv --with-snowcat Bug20201118.tla | sed 's/I@.*//' +$ apalache-mc check --length=10 --init=Init --next=Next --inv=Inv Bug20201118.tla | sed 's/I@.*//' ... The outcome is: NoError ... @@ -239,7 +239,7 @@ The outcome is: NoError ### check Fix333 succeeds: another regression for issue 333 ```sh -$ apalache-mc check --length=2 --init=Init --next=Next --inv=Inv --with-snowcat Fix333.tla | sed 's/I@.*//' +$ apalache-mc check --length=2 --init=Init --next=Next --inv=Inv Fix333.tla | sed 's/I@.*//' ... The outcome is: NoError ... @@ -248,7 +248,7 @@ The outcome is: NoError ### check Bug20190118 succeeds ```sh -$ apalache-mc check --length=1 --init=Init --next=Next --inv=Inv --with-snowcat Bug20190118.tla | sed 's/I@.*//' +$ apalache-mc check --length=1 --init=Init --next=Next --inv=Inv Bug20190118.tla | sed 's/I@.*//' ... The outcome is: NoError ... @@ -257,7 +257,7 @@ The outcome is: NoError ### check mis.tla succeeds ```sh -$ apalache-mc check --length=5 --inv=IsIndependent --with-snowcat mis.tla | sed 's/I@.*//' +$ apalache-mc check --length=5 --inv=IsIndependent mis.tla | sed 's/I@.*//' ... The outcome is: NoError ... @@ -266,7 +266,7 @@ The outcome is: NoError ### check mis_bug.tla errors ```sh -$ apalache-mc check --length=5 --inv=IsIndependent --with-snowcat mis_bug.tla | sed 's/I@.*//' +$ apalache-mc check --length=5 --inv=IsIndependent mis_bug.tla | sed 's/I@.*//' ... The outcome is: Error Checker has found an error @@ -640,9 +640,8 @@ The outcome is: NoError ### check Callback.tla succeeds -`Callback.tla` demonstrates that one can implement non-determinism with -the existential operator and use a callback to do an assignment to a variable. -As it requires tricky operator inlining, here is the test. +`Callback.tla` demonstrates that one can implement non-determinism with the existential operator and use a callback to +do an assignment to a variable. As it requires tricky operator inlining, here is the test. ```sh $ apalache-mc check Callback.tla | sed 's/I@.*//' @@ -651,6 +650,45 @@ The outcome is: NoError ... ``` +### check SimpT1 succeeds + +This test was moved from a unit test of SymbTransGenerator. The goal of the test is to check that symbolic transitions +are extracted from the spec. Hence, we run model checking only against the initial states. + +```sh +$ apalache-mc check --length=0 SimpT1.tla | sed 's/I@.*//' +... +The outcome is: NoError +... +``` + +### check Selections succeeds + +```sh +$ apalache-mc check Selections.tla | sed 's/I@.*//' +... +Selections.tla:16:18-16:27: Missing assignments to: z +... +EXITCODE: ERROR (99) +``` + +### check test1 succeeds + +```sh +$ apalache-mc check test1.tla | sed 's/I@.*//' +... +The outcome is: NoError +... +``` + +### check ITE_CASE succeeds + +```sh +$ apalache-mc check ITE_CASE.tla | sed 's/I@.*//' +... +EXITCODE: ERROR (99) +``` + ### check use of TLA_PATH for modules in child directory succeeds ```sh @@ -805,7 +843,7 @@ The outcome is: NoError ```sh $ apalache-mc check ConfigUnsorted.tla | sed 's/[IEW]@.*//' ... -Configuration error (see the manual): Found a cyclic dependency among operators: A, B, C +Configuration error (see the manual): Found a cyclic dependency among operators: B, A, C ... EXITCODE: ERROR (99) ``` @@ -1073,18 +1111,6 @@ Type checker [OK] EXITCODE: OK ``` -### typecheck HourClock.tla - -```sh -$ apalache-mc typecheck HourClock.tla | sed 's/[IEW]@.*//' -... -[HourClock.tla:6:12-6:13]: Undefined name hr. Introduce a type annotation. -... -Type checker [FAILED] -... -EXITCODE: OK -``` - ### typecheck HourClockTyped.tla ```sh diff --git a/test/tla/mis.tla b/test/tla/mis.tla index 022898fe02..11612d02e6 100644 --- a/test/tla/mis.tla +++ b/test/tla/mis.tla @@ -5,8 +5,6 @@ N == 3 N4 == 81 Nodes == 1..N -a <: b == a \* type annotations - VARIABLES \* @type: Set(<>); Nb, @@ -34,8 +32,7 @@ Init == \*/\ Nb = [ n \in Nodes |-> {Pred(n), Succ(n)} ] /\ awake = [n \in Nodes |-> TRUE] /\ rem_nbrs = [ u \in Nodes |-> { v \in Nodes : <> \in Nb}] /\ status = [n \in Nodes |-> "unknown"] - /\ msgs = [n \in Nodes |-> - ({} <: {[type |-> STRING, src |-> Int, val |-> Int ]})] + /\ msgs = [n \in Nodes |-> {}] Senders(u) == {v \in Nodes: awake[v] /\ u \in rem_nbrs[v] } @@ -55,10 +52,8 @@ Round1 == SentWinners(u) == IF \E w \in Senders(u): awake[w] /\ status[w] = "winner" - THEN {[type |-> "winner", src |-> u] - <: [type |-> STRING, src |-> Int, val |-> Int]} - ELSE ({} - <: {[type |-> STRING, src |-> Int, val |-> Int]}) + THEN {[type |-> "winner", src |-> u]} + ELSE {} IsLoser(u) == \E m \in msgs'[u]: m.type = "winner" @@ -70,8 +65,7 @@ Round2 == /\ UNCHANGED <> SentLosers(u) == - {([type |-> "loser", src |-> s] - <: [type |-> STRING, src |-> Int, val |-> Int]) + {[type |-> "loser", src |-> s] : s \in {w \in Senders(u): awake[w] /\ status[w] = "loser"}} ReceivedLosers(u) == diff --git a/test/tla/mis_bug.tla b/test/tla/mis_bug.tla index 0606fbf66c..7191c259cb 100644 --- a/test/tla/mis_bug.tla +++ b/test/tla/mis_bug.tla @@ -5,8 +5,6 @@ N == 3 N4 == 81 Nodes == 1..N -a <: b == a \* type annotations - VARIABLES \* @type: Set(<>); Nb, @@ -31,8 +29,7 @@ Init == /\ awake = [n \in Nodes |-> TRUE] /\ rem_nbrs = [ u \in Nodes |-> { v \in Nodes : <> \in Nb}] /\ status = [n \in Nodes |-> "unknown"] - /\ msgs = [n \in Nodes |-> - ({} <: {[type |-> STRING, src |-> Int, val |-> Int ]})] + /\ msgs = [n \in Nodes |-> {}] Senders(u) == {v \in Nodes: awake[v] /\ u \in rem_nbrs[v] } @@ -52,10 +49,8 @@ Round1 == SentWinners(u) == IF \E w \in Senders(u): awake[w] /\ status[w] = "winner" - THEN {[type |-> "winner", src |-> u] - <: [type |-> STRING, src |-> Int, val |-> Int]} - ELSE ({} - <: {[type |-> STRING, src |-> Int, val |-> Int]}) + THEN {[type |-> "winner", src |-> u]} + ELSE {} IsLoser(u) == \E m \in msgs'[u]: m.type = "winner" @@ -67,8 +62,7 @@ Round2 == /\ UNCHANGED <> SentLosers(u) == - {([type |-> "loser", src |-> s] - <: [type |-> STRING, src |-> Int, val |-> Int]) + {[type |-> "loser", src |-> s] : s \in {w \in Senders(u): awake[w] /\ status[w] = "loser"}} ReceivedLosers(u) == diff --git a/test/tla/pr.tla b/test/tla/pr.tla index c6e8cc5d7e..775a399f47 100644 --- a/test/tla/pr.tla +++ b/test/tla/pr.tla @@ -19,10 +19,14 @@ N == 4 dest == 1 Nodes == 1..N -a <: b == a -EmptyIntSet == {} <: {Int} -VARIABLES out, in, to_rev +VARIABLES + \* @type: Int -> Set(Int); + out, + \* @type: Int -> Set(Int); + in, + \* @type: Int -> Set(Int); + to_rev IsDag == \E po \in [Nodes -> [Nodes -> BOOLEAN]]: \* there is a partial order, that is, @@ -32,29 +36,30 @@ IsDag == /\ \A u \in Nodes: \A v \in out[u]: po[u][v] \* embeds neighborhood InitOut == - \* [ u \in Nodes |-> IF u = 1 THEN {2} ELSE EmptyIntSet ] + \* [ u \in Nodes |-> IF u = 1 THEN {2} ELSE {} ] [ u \in Nodes |-> IF u = 2 THEN {3, 5} ELSE IF u \in { 3, 7} THEN {4} ELSE IF u = 5 THEN {6} ELSE IF u = 6 THEN {3, 7} - ELSE EmptyIntSet \* u \in {1, 4} + ELSE {} \* u \in {1, 4} ] Init == /\ out = InitOut /\ in = [u \in Nodes |-> {v \in Nodes: u \in out[v]}] - /\ to_rev = [n \in Nodes |-> EmptyIntSet (*{}*)] + /\ to_rev = [n \in Nodes |-> {}] \*/\ out \in [Nodes -> SUBSET(Nodes)] \*/\ IsDag -Sinks == { n \in Nodes : out[n] = EmptyIntSet (*{}*) /\ n /= dest} +Sinks == { n \in Nodes : out[n] = {} /\ n /= dest} Reversed(Active) == [ u \in Nodes |-> \* TODO: figure out, why it does not work with Active IF u \notin Active - THEN EmptyIntSet + THEN {} ELSE IF in[u] /= to_rev[u] THEN in[u] \ to_rev[u] ELSE in[u] ] +\* @type: (Set(Int), Int -> Set(Int)) => Bool; UpdateOut(Active, rev) == out' = [ u \in Nodes |-> IF u \in Active @@ -63,6 +68,7 @@ UpdateOut(Active, rev) == \* another node => remove the nodes that active sinks reversed ] +\* @type: (Set(Int), Int -> Set(Int)) => Bool; UpdateIn(Active, rev) == in' = [ u \in Nodes |-> IF u \in Active @@ -71,10 +77,11 @@ UpdateIn(Active, rev) == \* another node => add the nodes from the sinks that reversed this edge ] +\* @type: (Set(Int), Int -> Set(Int)) => Bool; UpdateToRev(Active, rev) == to_rev' = [ u \in Nodes |-> IF u \in Active - THEN EmptyIntSet (*{}*) \* empty the list of the links to reverse + THEN {} ELSE to_rev[u] \cup { s \in Active : u \in rev[s] } \* when a neighbor s of u makes a reversal, u adds \* the link between u and s to the list diff --git a/test/tla/reorderTest.tla b/test/tla/reorderTest.tla index 51be9379be..eeae85841e 100644 --- a/test/tla/reorderTest.tla +++ b/test/tla/reorderTest.tla @@ -2,7 +2,11 @@ EXTENDS Integers -VARIABLE v1, v2 +VARIABLE + \* @type: Int; + v1, + \* @type: Int; + v2 Init == v1 = 1 /\ v2 = 1 diff --git a/tla-assignments/src/test/resources/assignmentSolver/test1.tla b/test/tla/test1.tla similarity index 72% rename from tla-assignments/src/test/resources/assignmentSolver/test1.tla rename to test/tla/test1.tla index dbc05affb1..665effdd7b 100644 --- a/tla-assignments/src/test/resources/assignmentSolver/test1.tla +++ b/test/tla/test1.tla @@ -1,6 +1,14 @@ --------------- MODULE test1 ------------- EXTENDS Naturals (*, TLC , Sequences*) -VARIABLE x, y, z, w +VARIABLE + \* @type: Set(Int); + x, + \* @type: Set(Str); + y, + \* @type: Set([a: Int, b: Int, c: Int]); + z, + \* @type: Set(Int -> Int); + w Init == (*/\ Print("Should find only one distinct state", TRUE)*) @@ -20,4 +28,4 @@ Next == \/ /\ x' = {3, 3, 2, 1} /\ UNCHANGED <> (*/\ Print("Test 1", TRUE)*) -============================================ \ No newline at end of file +============================================ diff --git a/test/tla/tla-path-tests/ImportedModule.tla b/test/tla/tla-path-tests/ImportedModule.tla index 9513db5103..72063b8a52 100644 --- a/test/tla/tla-path-tests/ImportedModule.tla +++ b/test/tla/tla-path-tests/ImportedModule.tla @@ -1,7 +1,9 @@ ---- MODULE ImportedModule ----------------------------------------------------- (* This trivial MODULE is just to be extended *) -VARIABLES x +VARIABLES + \* @type: Bool; + x Init == x = TRUE diff --git a/test/tla/y2k_instance.tla b/test/tla/y2k_instance.tla index 1ef5b71008..7aaf2156b0 100644 --- a/test/tla/y2k_instance.tla +++ b/test/tla/y2k_instance.tla @@ -4,7 +4,11 @@ * use INSTANCE. *) -VARIABLE year, hasLicense +VARIABLE + \* @type: Int; + year, + \* @type: Bool; + hasLicense INSTANCE y2k WITH BIRTH_YEAR <- 80, LICENSE_AGE <- 18 diff --git a/tla-assignments/src/main/scala/at/forsyte/apalache/tla/assignments/AssignmentOperatorIntroduction.scala b/tla-assignments/src/main/scala/at/forsyte/apalache/tla/assignments/AssignmentOperatorIntroduction.scala index b6370a8a28..00fb4ec96a 100644 --- a/tla-assignments/src/main/scala/at/forsyte/apalache/tla/assignments/AssignmentOperatorIntroduction.scala +++ b/tla-assignments/src/main/scala/at/forsyte/apalache/tla/assignments/AssignmentOperatorIntroduction.scala @@ -2,7 +2,6 @@ package at.forsyte.apalache.tla.assignments import at.forsyte.apalache.tla.lir.oper.{BmcOper, TlaActionOper, TlaOper} import at.forsyte.apalache.tla.lir._ -import at.forsyte.apalache.tla.lir.UntypedPredefs._ import at.forsyte.apalache.tla.lir.transformations.{TlaExTransformation, TransformationTracker} /** @@ -20,16 +19,18 @@ class AssignmentOperatorIntroduction( def transform: TlaExTransformation = tracker.trackEx { case ex @ OperEx(TlaOper.eq, prime @ OperEx(TlaActionOper.prime, _: NameEx), asgnVal) if isAssignment(ex.ID) => - val ret = OperEx(BmcOper.assign, prime, asgnVal) + val ret = OperEx(BmcOper.assign, prime, asgnVal)(ex.typeTag) uidReplacementMap += ex.ID -> ret.ID ret + case ex @ OperEx(op, args @ _*) => val newArgs = args map transform - if (args == newArgs) ex else OperEx(op, newArgs: _*) + if (args == newArgs) ex else OperEx(op, newArgs: _*)(ex.typeTag) + case ex @ LetInEx(body, defs @ _*) => val newDefs = defs.map { x => x.copy(body = transform(x.body)) } val newBody = transform(body) - if (defs == newDefs && body == newBody) ex else LetInEx(newBody, newDefs: _*) + if (defs == newDefs && body == newBody) ex else LetInEx(newBody, newDefs: _*)(ex.typeTag) case ex => ex } diff --git a/tla-assignments/src/main/scala/at/forsyte/apalache/tla/assignments/ModuleAdapter.scala b/tla-assignments/src/main/scala/at/forsyte/apalache/tla/assignments/ModuleAdapter.scala index e28332e829..70408e4af7 100644 --- a/tla-assignments/src/main/scala/at/forsyte/apalache/tla/assignments/ModuleAdapter.scala +++ b/tla-assignments/src/main/scala/at/forsyte/apalache/tla/assignments/ModuleAdapter.scala @@ -1,7 +1,6 @@ package at.forsyte.apalache.tla.assignments import at.forsyte.apalache.tla.lir._ -import at.forsyte.apalache.tla.typecheck.{OperT1, TlaType1} /** * Moving away from SpecWithTransitions ModuleManipulator allows us to re-insert special TlaEx diff --git a/tla-assignments/src/main/scala/at/forsyte/apalache/tla/assignments/SymbTransGenerator.scala b/tla-assignments/src/main/scala/at/forsyte/apalache/tla/assignments/SymbTransGenerator.scala index 6b4958bda3..2a7ddf9af8 100644 --- a/tla-assignments/src/main/scala/at/forsyte/apalache/tla/assignments/SymbTransGenerator.scala +++ b/tla-assignments/src/main/scala/at/forsyte/apalache/tla/assignments/SymbTransGenerator.scala @@ -1,10 +1,9 @@ package at.forsyte.apalache.tla.assignments -import at.forsyte.apalache.tla.lir._ +import at.forsyte.apalache.tla.lir.{BoolT1, _} import at.forsyte.apalache.tla.lir.oper._ import at.forsyte.apalache.tla.lir.transformations.{TlaExTransformation, TransformationTracker} import at.forsyte.apalache.tla.lir.values.TlaBool -import at.forsyte.apalache.tla.lir.UntypedPredefs._ /** * Constructs symbolic transitions from an assignment strategy. @@ -185,7 +184,7 @@ class SymbTransGenerator(tracker: TransformationTracker) { newArgs match { case Nil => ex case head +: Nil => sliceWith(selection, allSelections)(head) - case _ => OperEx(TlaBoolOper.or, newArgs map sliceWith(selection, allSelections): _*) + case _ => OperEx(TlaBoolOper.or, newArgs map sliceWith(selection, allSelections): _*)(ex.typeTag) } /** @@ -201,19 +200,19 @@ class SymbTransGenerator(tracker: TransformationTracker) { } ) sliceWith(selection, allSelections)(x) - else ValEx(TlaBool(false)) + else ValEx(TlaBool(false))(ex.typeTag) ) newTail match { case ValEx(TlaBool(false)) +: ValEx(TlaBool(false)) +: Nil => ex - case newThen +: ValEx(TlaBool(false)) +: Nil => - OperEx(TlaBoolOper.and, ifEx, newThen) - case ValEx(TlaBool(false)) +: newElse +: Nil => - OperEx(TlaBoolOper.and, OperEx(TlaBoolOper.not, ifEx), newElse) + case newThen +: (b @ ValEx(TlaBool(false))) +: Nil => + OperEx(TlaBoolOper.and, ifEx, newThen)(b.typeTag) + case (b @ ValEx(TlaBool(false))) +: newElse +: Nil => + OperEx(TlaBoolOper.and, OperEx(TlaBoolOper.not, ifEx)(b.typeTag), newElse)(b.typeTag) case _ => // Possible, because of LET-IN - OperEx(TlaControlOper.ifThenElse, ifEx +: newTail: _*) + OperEx(TlaControlOper.ifThenElse, ifEx +: newTail: _*)(ex.typeTag) } case ex @ OperEx(op, args @ _*) => @@ -222,7 +221,7 @@ class SymbTransGenerator(tracker: TransformationTracker) { } // Make sure to avoid creating new UIDs if not absolutely needed, as filtering // is done on the basis of UIDs not syntax - if (childVals == args) ex else OperEx(op, childVals: _*) + if (childVals == args) ex else OperEx(op, childVals: _*)(ex.typeTag) case ex @ LetInEx(body, defs @ _*) => val slice = sliceWith(selection, allSelections) @@ -235,8 +234,8 @@ class SymbTransGenerator(tracker: TransformationTracker) { val same = newDefs == defs && newBody == body if (same) ex - else if (newBody == ValEx(TlaBool(false))) newBody - else LetInEx(newBody, newDefs: _*) + else if (newBody == ValEx(TlaBool(false))(Typed(BoolT1()))) newBody + else LetInEx(newBody, newDefs: _*)(ex.typeTag) case ex => ex } diff --git a/tla-assignments/src/test/resources/assignmentSolver/Selections.tla b/tla-assignments/src/test/resources/assignmentSolver/Selections.tla deleted file mode 100644 index a233475d29..0000000000 --- a/tla-assignments/src/test/resources/assignmentSolver/Selections.tla +++ /dev/null @@ -1,17 +0,0 @@ ----------------------- MODULE Selections ---------------------- -VARIABLE x,y,z - -Next == /\ /\ \/ y' \in 1 - \/ y' \in 2 - /\ \/ x' \in 3 - \/ x' \in 4 - \/ /\ z' \in 5 - /\ y' \in 6 - /\ \/ z' \in 7 - \/ z' \in 8 - /\ \/ x' \in 9 - \/ /\ x' \in 10 - /\ x' \in 11 - -============================================================== - diff --git a/tla-assignments/src/test/scala/at/forsyte/apalache/tla/assignments/TestAlphaTransform.scala b/tla-assignments/src/test/scala/at/forsyte/apalache/tla/assignments/TestAlphaTransform.scala index 426070698c..e8fad1dfbe 100644 --- a/tla-assignments/src/test/scala/at/forsyte/apalache/tla/assignments/TestAlphaTransform.scala +++ b/tla-assignments/src/test/scala/at/forsyte/apalache/tla/assignments/TestAlphaTransform.scala @@ -6,7 +6,7 @@ import at.forsyte.apalache.tla.lir.convenience._ import at.forsyte.apalache.tla.lir.storage.{BodyMapFactory, ChangeListener} import at.forsyte.apalache.tla.lir.transformations.impl.TrackerWithListeners import at.forsyte.apalache.tla.lir.transformations.standard._ -import at.forsyte.apalache.tla.pp.Desugarer +import at.forsyte.apalache.tla.pp.{Desugarer, UniqueNameGenerator} import at.forsyte.apalache.tla.lir.UntypedPredefs._ import org.junit.runner.RunWith import org.scalatest.FunSuite @@ -16,32 +16,10 @@ import org.scalatest.junit.JUnitRunner class TestAlphaTransform extends FunSuite with TestingPredefs { val testFolderPath = "src/test/resources/assignmentSolver/" - def specFromFile(p_file: String, p_next: String = "Next"): TlaEx = { - val declsRaw = declarationsFromFile(testFolderPath + p_file) - - val fakeModule = new TlaModule("test", declsRaw) - - val tracker = TrackerWithListeners(new ChangeListener) - - val renaming = new IncrementalRenaming(tracker) - val uniqueVarDecls = - new TlaModule( - fakeModule.name, - renaming.syncAndNormalizeDs(fakeModule.declarations).toSeq - ) - - val bodyMap = BodyMapFactory.makeFromDecls(uniqueVarDecls.operDeclarations) - val inlined = ModuleByExTransformer(InlinerOfUserOper(bodyMap, tracker))(uniqueVarDecls) - val explLetIn = ModuleByExTransformer(LetInExpander(tracker, keepNullary = false))(inlined) - val preprocessed = ModuleByExTransformer(Desugarer(tracker))(explLetIn) - - findBodyOf(p_next, preprocessed.declarations: _*) - } - test("Star abstraction") { val ex1 = trueEx - val ex2: TlaEx = 5 + val ex2: TlaEx = tla.int(5) val ex3: TlaEx = tla.in(n_x, n_S) val ex4: TlaEx = tla.choose(n_x, n_S, n_p) val ex5: TlaEx = tla.caseOther(n_c, n_p, n_a, n_q, n_b) @@ -127,10 +105,4 @@ class TestAlphaTransform extends FunSuite with TestingPredefs { assert(correctRecursiveApplication(Seq(ex1, ex2))) } - - test("Real spec") { - val spec = specFromFile("Paxos.tla") - - assert(correctRecursiveApplication(Seq(spec))) - } } diff --git a/tla-assignments/src/test/scala/at/forsyte/apalache/tla/assignments/TestSymbTransGenerator.scala b/tla-assignments/src/test/scala/at/forsyte/apalache/tla/assignments/TestSymbTransGenerator.scala index 1754b7607d..db72caf62e 100644 --- a/tla-assignments/src/test/scala/at/forsyte/apalache/tla/assignments/TestSymbTransGenerator.scala +++ b/tla-assignments/src/test/scala/at/forsyte/apalache/tla/assignments/TestSymbTransGenerator.scala @@ -3,7 +3,8 @@ package at.forsyte.apalache.tla.assignments import at.forsyte.apalache.tla.lir.oper.TlaActionOper import at.forsyte.apalache.tla.lir.transformations.impl.{IdleTracker, TrackerWithListeners} import at.forsyte.apalache.tla.lir._ -import at.forsyte.apalache.tla.lir.UntypedPredefs._ +import at.forsyte.apalache.tla.lir.convenience.tla +import TypedPredefs._ import org.junit.runner.RunWith import org.scalatest.FunSuite import org.scalatest.junit.JUnitRunner @@ -16,6 +17,8 @@ class TestSymbTransGenerator extends FunSuite with TestingPredefs { import stg.helperFunctions._ import at.forsyte.apalache.tla.lir.convenience.tla._ + private val types = Map("i" -> IntT1(), "b" -> BoolT1(), "o_b" -> OperT1(Seq(), BoolT1())) + test("Test allCombinations") { assert(allCombinations[Int](Seq.empty[Set[Set[Int]]]).isEmpty) @@ -57,9 +60,14 @@ class TestSymbTransGenerator extends FunSuite with TestingPredefs { } test("Test labelsAt") { - val ex11 = n_x - val ex12 = n_y - val ex13 = or(ex11, ex12).untyped() + val ex11 = tla + .name("x") + .typed(IntT1()) + val ex12 = tla + .name("y") + .typed(IntT1()) + val ex13 = or(ex11, ex12) + .typed(BoolT1()) val sel1: SelMapType = Map( ex13.ID -> Set(Set(ex11.ID), Set(ex12.ID)) @@ -78,24 +86,33 @@ class TestSymbTransGenerator extends FunSuite with TestingPredefs { } test("Test allSelections") { - val xasgn11 = primeEq(n_x, n_s).untyped() - val xasgn12 = primeEq(n_x, int(1)).untyped() - val yasgn11 = primeEq(n_x, n_T).untyped() - val yasgn12 = primeEq(n_x, n_t).untyped() + val xasgn11 = tla + .eql(tla.prime(tla.name("x") ? "i") ? "i", tla.name("s") ? "i") + .typed(types, "b") + val xasgn12 = tla + .eql(tla.prime(tla.name("x") ? "i") ? "i", int(1)) + .typed(types, "b") + val yasgn11 = tla + .eql(tla.prime(tla.name("x") ? "i") ? "i", tla.name("T") ? "i") + .typed(types, "b") + val yasgn12 = tla + .eql(tla.prime(tla.name("x") ? "i") ? "i", tla.name("t") ? "i") + .typed(types, "b") val ex1 = ite( - ge(int(0), int(1)), + ge(int(0), int(1)) ? "b", xasgn11, xasgn12 - ).untyped() + ).typed(types, "b") - val ex2 = or(yasgn11, yasgn12).untyped() + val ex2 = or(yasgn11, yasgn12) + .typed(types, "b") val ex3 = and( ex1, ex2 - ).untyped() + ).typed(types, "b") val possibleAssgnsX = Seq( Set(xasgn11.ID), @@ -132,16 +149,28 @@ class TestSymbTransGenerator extends FunSuite with TestingPredefs { assert(s(newEx.ID) == Set(e)) } - val xasgn21 = primeEq(n_x, n_s).untyped() - val yasgn21 = primeEq(n_x, n_T).untyped() - val yasgn22 = primeEq(n_y, n_t).untyped() - - val ex4 = and(eql(int(0), int(1)), xasgn21).untyped() - val xDecl = declOp("X", ex4).untypedOperDecl() - val ex5 = and(yasgn21, appDecl(xDecl)).untyped() - val ex6 = and(yasgn22, appDecl(xDecl)).untyped() - val ex7 = or(ex5, ex6).untyped() - val ex8 = letIn(ex7, xDecl).untyped() + val xasgn21 = tla + .eql(tla.prime(tla.name("x") ? "i") ? "i", tla.name("x") ? "i") + .typed(types, "b") + val yasgn21 = tla + .eql(tla.prime(tla.name("x") ? "i") ? "i", tla.name("T") ? "i") + .typed(types, "b") + val yasgn22 = tla + .eql(tla.prime(tla.name("y") ? "i") ? "i", tla.name("t") ? "i") + .typed(types, "b") + + val ex4 = and(eql(int(0), int(1)) ? "b", xasgn21 ? "b") + .typed(types, "b") + val xDecl = declOp("X", ex4) + .typedOperDecl(OperT1(Seq(), BoolT1())) + val ex5 = and(yasgn21, tla.appOp(tla.name("X") ? "o_b") ? "b") + .typed(types, "b") + val ex6 = and(yasgn22, tla.appOp(tla.name("X") ? "o_b") ? "b") + .typed(types, "b") + val ex7 = or(ex5, ex6) + .typed(types, "b") + val ex8 = letIn(ex7, xDecl) + .typed(types, "b") val possibleAssgnsX2 = Seq(Set(xasgn21.ID)) @@ -169,19 +198,25 @@ class TestSymbTransGenerator extends FunSuite with TestingPredefs { } test("Test ITE with multibranching") { - val asgn1 = primeEq(n_x, int(1)).untyped() - val asgn2 = primeEq(n_x, int(2)).untyped() - val asgn3 = primeEq(n_x, int(3)).untyped() + val asgn1 = tla + .eql(tla.prime(tla.name("x") ? "i") ? "i", int(1)) + .typed(types, "b") + val asgn2 = tla + .eql(tla.prime(tla.name("x") ? "i") ? "i", int(2)) + .typed(types, "b") + val asgn3 = tla + .eql(tla.prime(tla.name("x") ? "i") ? "i", int(3)) + .typed(types, "b") val next = ite( - trueEx, + tla.bool(true).typed(), asgn1, ite( - trueEx, + tla.bool(true).typed(), asgn2, asgn3 - ) - ).untyped() + ) ? "b" + ).typed(types, "b") val sel = Seq(asgn1.ID, asgn2.ID, asgn3.ID) @@ -202,17 +237,20 @@ class TestSymbTransGenerator extends FunSuite with TestingPredefs { } test("Test LET-IN") { - val asgn = primeEq(n_x, int(1)).untyped() - val xDecl = declOp("X", asgn).untypedOperDecl() + val asgn = tla + .eql(tla.prime(tla.name("x") ? "i") ? "i", int(1)) + .typed(types, "b") + val xDecl = declOp("X", asgn) + .typedOperDecl(types, "o_b") val disj = or( - and(n_A, appDecl(xDecl)), - and(n_B, appDecl(xDecl)) - ).untyped() + and(tla.name("A") ? "b", tla.appOp(tla.name("X") ? "o_b") ? "b") ? "b", + and(tla.name("B") ? "b", tla.appOp(tla.name("X") ? "o_b") ? "b") ? "b" + ).typed(types, "b") val next = letIn( disj, xDecl - ).untyped() + ).typed(types, "b") val strat = Seq(asgn.ID) @@ -220,22 +258,27 @@ class TestSymbTransGenerator extends FunSuite with TestingPredefs { _._2 } assert(ret.size == 1) - val expected = letIn(disj, declOp("X", assignPrime(n_x, int(1)).untyped()).untypedOperDecl()).untyped() + val expected = letIn(disj, + declOp("X", assign(prime(tla.name("x") ? "i") ? "i", int(1)) ? "b") + .typedOperDecl(types, "o_b")) + .typed(types, "b") assert(expected == ret.head) } test("Test sliceWith") { - val asgn = primeEq(n_x, int(1)).untyped() - val xDecl = declOp("X", asgn).untypedOperDecl() + val asgn = eql(prime(tla.name("x") ? "i") ? "i", int(1)) + .typed(types, "b") + val xDecl = declOp("X", asgn) + .typedOperDecl(types, "o_b") val disj = or( - and(n_A, appDecl(xDecl)), - and(n_B, appDecl(xDecl)) - ).untyped() + and(name("A"), appOp(name("X") ? "o_b") ? "b") ? "b", + and(name("B"), appOp(name("X") ? "o_b") ? "b") ? "b" + ).typed(types, "b") val next = letIn( disj, xDecl - ).untyped() + ).typed(types, "b") val selection = Set(asgn.ID) val tr = AssignmentOperatorIntroduction(selection, new IdleTracker) diff --git a/tla-assignments/src/test/scala/at/forsyte/apalache/tla/assignments/TestSymbTransPass.scala b/tla-assignments/src/test/scala/at/forsyte/apalache/tla/assignments/TestSymbTransPass.scala index 37f13d2e28..36e02734a9 100644 --- a/tla-assignments/src/test/scala/at/forsyte/apalache/tla/assignments/TestSymbTransPass.scala +++ b/tla-assignments/src/test/scala/at/forsyte/apalache/tla/assignments/TestSymbTransPass.scala @@ -1,13 +1,12 @@ package at.forsyte.apalache.tla.assignments -import at.forsyte.apalache.tla.imp.declarationsFromFile -import at.forsyte.apalache.tla.lir.{NullEx, TestingPredefs, TlaDecl, TlaEx, TlaModule, TlaOperDecl, TlaVarDecl, UID} +import at.forsyte.apalache.tla.lir.UntypedPredefs._ +import at.forsyte.apalache.tla.lir.convenience.tla import at.forsyte.apalache.tla.lir.storage.{BodyMapFactory, ChangeListener} import at.forsyte.apalache.tla.lir.transformations.impl.TrackerWithListeners import at.forsyte.apalache.tla.lir.transformations.standard._ -import at.forsyte.apalache.tla.lir.convenience.tla -import at.forsyte.apalache.tla.lir.UntypedPredefs._ -import at.forsyte.apalache.tla.pp.Desugarer +import at.forsyte.apalache.tla.lir._ +import at.forsyte.apalache.tla.pp.{Desugarer, UniqueNameGenerator} import org.junit.runner.RunWith import org.scalatest.FunSuite import org.scalatest.junit.JUnitRunner @@ -35,28 +34,14 @@ class TestSymbTransPass extends FunSuite with TestingPredefs { val bodyMap = BodyMapFactory.makeFromDecls(uniqueVarDecls.operDeclarations) val inlined = ModuleByExTransformer(InlinerOfUserOper(bodyMap, tracker))(uniqueVarDecls) val explLetIn = ModuleByExTransformer(LetInExpander(tracker, keepNullary = false))(inlined) - val preprocessed = ModuleByExTransformer(Desugarer(tracker))(explLetIn) + val gen = new UniqueNameGenerator + val preprocessed = ModuleByExTransformer(Desugarer(gen, tracker))(explLetIn) val vars = preprocessed.varDeclarations.map(_.name) SymbolicTransitionExtractor(tracker)(vars, preprocessed.operDeclarations.find(_.name == p_next).get.body) } - def testFromFile(p_file: String, p_next: String = "Next"): Seq[SymbTrans] = { - - val decls = declarationsFromFile(testFolderPath + p_file) - - val ret = testFromDecls(decls, p_next) - - // val saveWriter = new PrintWriter( new File( testFolderPath + "SymbNexts" + p_file ) ) - - // ret.foreach( x => saveWriter.println( "%s : \n %s\n".format( x._1.map( UniqueDB.get ) , x._2 ) ) ) - - // saveWriter.close() - - ret - } - test("Test labelsAt") { val gen = new SymbTransGenerator(TrackerWithListeners()) @@ -100,33 +85,4 @@ class TestSymbTransPass extends FunSuite with TestingPredefs { val trans = testFromDecls(decls) assert(trans.isEmpty) } - - test("Test Selections") { - val symbNexts = testFromFile("Selections.tla") - } - - test("Test Paxos (simplified)") { - val symbNexts = testFromFile("Paxos.tla") - } - - test("Test ITE_CASE") { - val symbNexts = testFromFile("ITE_CASE.tla") - } - - test("Test EWD840") { - val symbNexts = testFromFile("EWD840.tla") - } - - test("AST") { - val symbNexts = testFromFile("ast.tla") - } - - test("test1") { - val symbNexts = testFromFile("test1.tla") - } - - test("SimpT1") { - val symbNexts = testFromFile("SimpT1.tla") - val symbNexts2 = testFromFile("SimpT1.tla", "NextNoFaults") - } } diff --git a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/LazyEquality.scala b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/LazyEquality.scala index 7be1f84078..d065ad9efb 100644 --- a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/LazyEquality.scala +++ b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/LazyEquality.scala @@ -4,10 +4,9 @@ import at.forsyte.apalache.tla.bmcmt.caches.{EqCache, EqCacheSnapshot} import at.forsyte.apalache.tla.bmcmt.implicitConversions._ import at.forsyte.apalache.tla.bmcmt.rewriter.{ConstSimplifierForSmt, Recoverable} import at.forsyte.apalache.tla.bmcmt.types._ -import at.forsyte.apalache.tla.lir.convenience.tla -import at.forsyte.apalache.tla.lir.{NullEx, TlaEx} - import at.forsyte.apalache.tla.lir.UntypedPredefs._ +import at.forsyte.apalache.tla.lir.convenience.tla +import at.forsyte.apalache.tla.lir.{MalformedTlaError, NullEx, TlaEx} /** * Generate equality constraints between cells and cache them to avoid redundant constraints. @@ -123,9 +122,11 @@ class LazyEquality(rewriter: SymbStateRewriter) } else if (cacheEntry.isDefined) { state // do nothing } else if (!left.cellType.comparableWith(right.cellType)) { - // cells of incomparable types cannot be equal - eqCache.put(left, right, EqCache.FalseEntry()) - state + // Cells of incomparable types cannot be equal. + // This is a dangerous state, as the type checker should have caught this. Throw an error. + // It is not really a typing error, but an internal error that should be treated as such. + val msg = "Checking values of incomparable types for equality: %s and %s".format(left.cellType, right.cellType) + throw new MalformedTlaError(msg, state.ex) } else { // generate constraints val newState = @@ -149,8 +150,11 @@ class LazyEquality(rewriter: SymbStateRewriter) case (SeqT(_), SeqT(_)) => mkSeqEq(state, left, right) - case _ => - throw new CheckerException("Unexpected equality test", state.ex) + case (FinFunSetT(_, _), FinFunSetT(_, _)) => + mkFunSetEq(state, left, right) + + case (lt, rt) => + throw new CheckerException(s"Unexpected equality test over types $lt and $rt", state.ex) } // return the new state @@ -238,6 +242,22 @@ class LazyEquality(rewriter: SymbStateRewriter) } } + private def mkFunSetEq(state: SymbState, left: ArenaCell, right: ArenaCell): SymbState = { + val dom1 = state.arena.getDom(left) + val cdm1 = state.arena.getCdm(left) + val dom2 = state.arena.getDom(right) + val cdm2 = state.arena.getCdm(right) + var nextState = mkSetEq(state, dom1, dom2) + nextState = mkSetEq(nextState, cdm1, cdm2) + val eq = tla.equiv(tla.eql(left.toNameEx, right.toNameEx), + tla.and(tla.eql(dom1.toNameEx, dom2.toNameEx), tla.eql(cdm1.toNameEx, cdm2.toNameEx))) + rewriter.solverContext.assertGroundExpr(eq) + eqCache.put(left, right, EqCache.EqEntry()) + + // recover the original expression and theory + nextState.setRex(state.ex) + } + // statically empty sets should be handled with care private def mkEmptySetEq(state: SymbState, emptySet: ArenaCell, otherSet: ArenaCell): SymbState = { val otherElems = state.arena.getHas(otherSet) diff --git a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/ModelChecker.scala b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/ModelChecker.scala deleted file mode 100644 index 7e18715c50..0000000000 --- a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/ModelChecker.scala +++ /dev/null @@ -1,650 +0,0 @@ -package at.forsyte.apalache.tla.bmcmt - -import java.io.{File, FileWriter, PrintWriter, StringWriter} - -import at.forsyte.apalache.tla.bmcmt.analyses.{ExprGradeStore, FormulaHintsStore} -import at.forsyte.apalache.tla.bmcmt.rewriter.{ConstSimplifierForSmt, MetricProfilerListener, RewriterConfig} -import at.forsyte.apalache.tla.bmcmt.rules.aux.{CherryPick, MockOracle, Oracle} -import at.forsyte.apalache.tla.bmcmt.search.SearchStrategy -import at.forsyte.apalache.tla.bmcmt.search.SearchStrategy._ -import at.forsyte.apalache.tla.bmcmt.smt.{SolverConfig, SolverContext, Z3SolverContext} -import at.forsyte.apalache.tla.bmcmt.types._ -import at.forsyte.apalache.tla.bmcmt.util.TlaExUtil -import at.forsyte.apalache.tla.imp.src.SourceStore -import at.forsyte.apalache.tla.lir._ -import at.forsyte.apalache.tla.lir.convenience.tla -import at.forsyte.apalache.tla.lir.io._ -import at.forsyte.apalache.tla.lir.storage.{ChangeListener, SourceLocator} -import at.forsyte.apalache.tla.lir.values.TlaBool -import at.forsyte.apalache.tla.lir.UntypedPredefs._ -import com.typesafe.scalalogging.LazyLogging - -import scala.collection.immutable.SortedMap -import scala.collection.mutable.ListBuffer - -/** - * A bounded model checker using SMT. The checker itself does not implement a particular search. Instead, - * it queries a search strategy, e.g., DfsStrategy or BfsStrategy. - * - * We expect the invariant to be negated and written over prime variables. - * - * @author Igor Konnov - */ -@deprecated("Use SeqModelChecker") -class ModelChecker(typeFinder: TypeFinder[CellT], formulaHintsStore: FormulaHintsStore, changeListener: ChangeListener, - exprGradeStore: ExprGradeStore, sourceStore: SourceStore, checkerInput: CheckerInput, - searchStrategy: SearchStrategy, tuningOptions: Map[String, String], debug: Boolean = false, - profile: Boolean = false, checkRuntime: Boolean = false) - extends Checker with LazyLogging { - - import Checker._ - - class CancelSearchException(val outcome: Outcome.Value) extends Exception - - /** - * A stack of the symbolic states that might constitute a counterexample (the last state is on top). - */ - private var stack: List[(SymbState, Oracle)] = List() - private var typesStack: Seq[SortedMap[String, CellT]] = Seq() - private val solverContext: SolverContext = - new Z3SolverContext(SolverConfig(debug, profile, randomSeed = tuningOptions.getOrElse("smt.randomSeed", "0").toInt)) - // TODO: figure out why the preprocessor slows down invariant checking. Most likely, there is a bug. - // new PreproSolverContext(new Z3SolverContext(debug, profile)) - - private val rewriter: SymbStateRewriterImpl = - new SymbStateRewriterImpl(solverContext, typeFinder, exprGradeStore, - if (profile) { - Some(new MetricProfilerListener(sourceStore, changeListener, new File("profile.csv"))) - } else { - None - }) - - rewriter.formulaHintsStore = formulaHintsStore - rewriter.config = RewriterConfig(tuningOptions) - - private val stepFilters: Seq[String] = - tuningOptions.getOrElse("search.transitionFilter", ".*").split(",") - - private val invFilter: String = - tuningOptions.getOrElse("search.invariantFilter", "") - - private val invariantSplitByTransition: Boolean = - tuningOptions.getOrElse("search.invariant.split", "true").toLowerCase == "true" - - private val learnTransFromUnsat: Boolean = - tuningOptions.getOrElse("search.transition.learnFromUnsat", "").toLowerCase == "true" - - private val learnInvFromUnsat: Boolean = - tuningOptions.getOrElse("search.invariant.learnFromUnsat", "").toLowerCase == "true" - - private val transitionTimeout: Long = - BigInt(tuningOptions.getOrElse("search.transition.timeout", "0")).toLong - - private val invariantTimeout: Long = - BigInt(tuningOptions.getOrElse("search.invariant.timeout", "0")).toLong - - /** - * A list of transitions that are enabled at every step - */ - private var enabledList: Seq[Seq[Int]] = Seq() - - /** - * A set of CONSTANTS, which are special (rigid) variables, as they do not change in the course of execution. - */ - private val constants = Set(checkerInput.rootModule.constDeclarations.map(_.name): _*) - - /** - * Check all executions of a TLA+ specification up to a bounded number of steps. - * - * @return a verification outcome - */ - def run(): Outcome.Value = { - val initialArena = Arena.create(solverContext) - val dummyState = new SymbState(initialArena.cellTrue().toNameEx, initialArena, Binding()) - val outcome = - try { - val initConstState = initializeConstants(dummyState) - stack +:= (initConstState, new MockOracle(0)) - typesStack +:= typeFinder.varTypes // the type of CONSTANTS have been computed already - applySearchStrategy() - } catch { - case _: CancelSearchException => - Outcome.Error - - case err: CheckerException => - // try to get any info about the problematic source location - printRewriterSourceLoc() - throw err - } - // flush the logs - rewriter.dispose() - outcome - } - - /** - * Use the provided expression to initialize the constants - * - * @param state an initial state - * @return a new state with the constants initialized - */ - private def initializeConstants(state: SymbState): SymbState = { - val newState = - if (checkerInput.constInitPrimed.isEmpty) { - logger.info("No CONSTANT initializer given") - state - } else { - logger.info("Initializing CONSTANTS with the provided operator") - checkTypes(checkerInput.constInitPrimed.get) - val nextState = rewriter.rewriteUntilDone(state.setRex(checkerInput.constInitPrimed.get)) - // importantly, assert the constraints that are imposed by the expression - rewriter.solverContext.assertGroundExpr(nextState.ex) - // as the initializer was defined over the primed versions of the constants, shift them back to non-primed - shiftTypes(Set()) - nextState.setBinding(shiftBinding(nextState.binding, Set())) - } - - val constants = checkerInput.rootModule.constDeclarations.map(_.name) - val uninitialized = constants.filter(n => !newState.binding.contains(n)) - if (uninitialized.nonEmpty) { - logger.error("The following CONSTANTS are not initialized: " + uninitialized.mkString(", ")) - throw new CancelSearchException(Checker.Outcome.RuntimeError) - } - newState - } - - private def applySearchStrategy(): Outcome.Value = { - searchStrategy.getCommand match { - case Finish() => Outcome.NoError // done - - case FinishOnDeadlock() => - logger.error(s"Deadlock detected.") - if (solverContext.sat()) { - val filenames = dumpCounterexample(ValEx(TlaBool(true))) - logger.error(s"Check an execution leading to a deadlock state in any of ${filenames.mkString(", ")}") - } else { - logger.error("No model found that would describe the deadlock") - } - Outcome.Deadlock - - case BacktrackOnce() => - rewriter.pop() - logger.debug("Backtracking to level %d".format(rewriter.contextLevel)) - stack = stack.tail - typesStack = typesStack.tail - searchStrategy.registerResponse(Backtracked()) - applySearchStrategy() - - case NextStep(stepNo: Int, transitionNos: Seq[Int], popContext: Boolean) => - def filter(trs: Seq[TlaEx]): Seq[(TlaEx, Int)] = { - trs.zipWithIndex filter (p => transitionNos.contains(p._2)) - } - - assert(rewriter.contextLevel == stepNo) - val (state, _) = stack.head - val types = typesStack.head - typeFinder.reset(types) // set the variable type as they should be at this step - - val transitions = - if (stepNo == 0) filter(checkerInput.initTransitions) else filter(checkerInput.nextTransitions) - - // make the step - val transWithEnabled = findEnabledOrBugs(stepNo, state, transitions.toList) - val enabledExists = transWithEnabled.exists(_._2) - if (!enabledExists) { - // no push here, as the transition is disabled - searchStrategy.registerResponse(SearchStrategy.NextStepDisabled()) - } else { - rewriter.push() // this is needed for backtracking, LEVEL + 1 - val result = applyEnabled(stepNo, state, transWithEnabled) - assert(result.isDefined) - searchStrategy.registerResponse(SearchStrategy.NextStepFired()) - } - - applySearchStrategy() // next step - } - } - - private def findEnabledOrBugs(stepNo: Int, startingState: SymbState, - transitionsAndNos: List[(TlaEx, Int)]): List[((TlaEx, Int), Boolean)] = { - // find all the feasible transitions and check the invariant for each transition - logger.info( - "Step %d, level %d: checking if %d transition(s) are enabled and violate the invariant" - .format(stepNo, rewriter.contextLevel, transitionsAndNos.length)) - - def filterEnabled(state: SymbState, ts: List[(TlaEx, Int)]): List[((TlaEx, Int), Boolean)] = { - ts match { - case List() => List() - - case tranWithNo :: tail => - val (transition, transitionNo) = tranWithNo - if (!stepMatchesFilter(stepNo, transitionNo)) { - filterEnabled(state, tail) // just skip this transition - } else { - val erased = state.setBinding(forgetPrimed(state.binding)) - rewriter.push() // LEVEL + 1 - val (nextState, isEnabled) = - applyTransition(stepNo, erased, transition, transitionNo, checkForErrors = true) - rewriter.exprCache.disposeActionLevel() // leave only the constants - rewriter.pop() // forget all the constraints that were generated by the transition, LEVEL + 0 - (tranWithNo, isEnabled) +: filterEnabled(state, tail) - } - } - } - - val savedVarTypes = rewriter.typeFinder.varTypes // save the variable types before applying the transitions - val enabled = filterEnabled(startingState, transitionsAndNos) - /* - enabledList :+= enabled map (_._2) // put it on stack, FIXME: this will not work well with DFS... - dumpEnabledMap() - */ - // restore the variable types to apply the enabled transitions once again - rewriter.typeFinder.reset(savedVarTypes) - enabled - } - - private def applyEnabled(stepNo: Int, startingState: SymbState, - transWithEnabled: List[((TlaEx, Int), Boolean)]): Option[SymbState] = { - // second, apply the enabled transitions and collect their effects - logger.info( - "Step %d, level %d: collecting %d enabled transition(s)" - .format(stepNo, rewriter.contextLevel, transWithEnabled.count(_._2))) - assert(transWithEnabled.nonEmpty) - val simplifier = new ConstSimplifierForSmt() - - def applyTrans(state: SymbState, ts: List[((TlaEx, Int), Boolean)]): List[SymbState] = - ts match { - case List() => - List(state) // the final state may contain additional cells, add it - - case (tranWithNo, isEnabled) :: tail => - if (!isEnabled && !learnTransFromUnsat) { - applyTrans(state, tail) // ignore the disabled transition, without any rewriting - } else { - val (transition, transitionNo) = tranWithNo - val erased = state.setBinding(forgetPrimed(state.binding)) - // note that the constraints are added at the current level, without an additional push - var (nextState, _) = applyTransition(stepNo, erased, transition, transitionNo, checkForErrors = false) - rewriter.exprCache.disposeActionLevel() // leave only the constants - if (isEnabled && learnInvFromUnsat && invariantSplitByTransition) { - nextState = assumeInvariant(stepNo, nextState) - } - if (isEnabled) { - // collect the variables of the enabled transition - nextState +: applyTrans(nextState, tail) - } else { - assert(learnTransFromUnsat) - // Do not collect the variables from the disabled transition, but remember that it was disabled. - // Note that the constraints are propagated via nextState - solverContext.assertGroundExpr(simplifier.simplifyShallow(tla.not(nextState.ex))) - applyTrans(nextState, tail) - } - } - } - - val nextAndLastStates = applyTrans(startingState, transWithEnabled) - val lastState = nextAndLastStates.last - val nextStates = nextAndLastStates.slice(0, nextAndLastStates.length - 1) - - val picker = new CherryPick(rewriter) - // pick an index j \in { 0..k } of the fired transition - val (oracleState, oracle) = picker.oracleFactory.newDefaultOracle(lastState, nextStates.length) - - if (nextStates.isEmpty) { - throw new IllegalArgumentException("enabled must be non-empty") - } else if (nextStates.lengthCompare(1) == 0) { - val resultingState = oracleState.setBinding(shiftBinding(lastState.binding, constants)) - solverContext.assertGroundExpr(lastState.ex) - if (!invariantSplitByTransition) { checkAllInvariants(stepNo, 0, resultingState) } - stack +:= (resultingState, oracle) // in this case, oracle is always zero - shiftTypes(constants) - typesStack = typeFinder.varTypes +: typesStack - Some(resultingState) - } else { - // if oracle = i, then the ith transition is enabled - solverContext.assertGroundExpr(oracle.caseAssertions(oracleState, nextStates.map(_.ex))) - - // glue the computed states S_0, ..., S_k together: - // for every variable x', pick c_x from { S_1[x'], ..., S_k[x'] } - // and require \A i \in { 0.. k-1}. j = i => c_x = S_i[x'] - // Then, the final state binds x' -> c_x for every x' \in Vars' - def getAssignedVars(st: SymbState) = forgetNonPrimed(st.binding, Set()).toMap.keySet - - val primedVars = getAssignedVars(nextStates.head) // only VARIABLES, not CONSTANTS - var finalState = oracleState - if (nextStates.exists(getAssignedVars(_) != primedVars)) { - val index = nextStates.indexWhere(getAssignedVars(_) != primedVars) - val otherSet = getAssignedVars(nextStates(index)) - val diff = otherSet.union(primedVars).diff(otherSet.intersect(primedVars)) - val msg = - "[Step %d] Next states 0 and %d disagree on the set of assigned variables: %s" - .format(stepNo, index, diff.mkString(", ")) - throw new InternalCheckerError(msg, finalState.ex) - } - - def pickVar(x: String): ArenaCell = { - val toPickFrom = nextStates map (_.binding(x)) - finalState = picker.pickByOracle(finalState, oracle, toPickFrom, - finalState.arena.cellFalse().toNameEx) // no else case - finalState.asCell - } - - val finalVarBinding = Binding(primedVars.toSeq map (n => (n, pickVar(n))): _*) // variables only - val constBinding = oracleState.binding.toMap.filter(p => constants.contains(p._1)) - finalState = finalState.setBinding(Binding(finalVarBinding.toMap ++ constBinding)) - if (debug && !solverContext.sat()) { - throw new InternalCheckerError(s"Error picking next variables (step $stepNo). Report a bug.", finalState.ex) - } - // check the invariant, if search invariant.split=false - if (!invariantSplitByTransition) { checkAllInvariants(stepNo, 0, finalState) } - if (learnInvFromUnsat && !invariantSplitByTransition) { - finalState = assumeInvariant(stepNo, finalState) - } - // finally, shift the primed variables to non-primed - finalState = finalState.setBinding(shiftBinding(finalState.binding, constants)) - // that is the result of this step - stack +:= (finalState, oracle) - // here we save the transition index, not the oracle, which will be shown to the user - shiftTypes(constants) - typesStack = typeFinder.varTypes +: typesStack - Some(finalState) - } - } - - // This method adds constraints right in the current context, without doing push - private def applyTransition(stepNo: Int, state: SymbState, transition: TlaEx, transitionNo: Int, - checkForErrors: Boolean): (SymbState, Boolean) = { - logger.debug( - "Step #%d, transition #%d, SMT context level %d." - .format(stepNo, transitionNo, rewriter.contextLevel)) - logger.debug("Finding types of the variables...") - checkTypes(transition) - solverContext.log( - "; ------- STEP: %d, STACK LEVEL: %d TRANSITION: %d {" - .format(stepNo, rewriter.contextLevel, transitionNo)) - logger.debug("Applying rewriting rules...") - var nextState = rewriter.rewriteUntilDone(state.setRex(transition)) - rewriter.flushStatistics() - if (checkForErrors && debug) { - // This is a debugging feature that allows us to find incorrect rewriting rules. - // Disable it in production. - logger.debug("Finished rewriting. Checking satisfiability of the pushed constraints.") - solverContext.satOrTimeout(transitionTimeout) match { - case Some(false) => - // this is a clear sign of a bug in one of the translation rules - logger.debug("UNSAT after pushing transition constraints") - throw new CheckerException("A contradiction introduced in rewriting. Report a bug.", state.ex) - - case Some(true) => - () // SAT - logger.debug("The transition constraints are OK.") - - case None => // interpret it as sat - logger.debug("Timeout. Assuming the transition constraints to be OK.") - } - } - if (!checkForErrors) { - // this was an experimental feature, which did not work nicely - // assume no failure occurs - // val failPreds = state.arena.findCellsByType(FailPredT()) - // failPreds.map(fp => tla.not(fp.toNameEx)) foreach solverContext.assertGroundExpr - // just return the state - (nextState, true) - // LEVEL + 0 - } else { - rewriter.push() // LEVEL + 1 - // assume the constraint constructed by this transition - solverContext.assertGroundExpr(nextState.ex) - // check whether this transition violates some assertions - logger.debug("Checking transition feasibility.") - solverContext.satOrTimeout(transitionTimeout) match { - case Some(true) => - if (invariantSplitByTransition) { - // check the invariant as soon as one transition has been applied - checkAllInvariants(stepNo, transitionNo, nextState) - } - // and then forget all these constraints! - rewriter.pop() // LEVEL + 0 - solverContext.log("; } ------- STEP: %d, STACK LEVEL: %d".format(stepNo, rewriter.contextLevel)) - (nextState, true) - // LEVEL + 0 - - case r: Option[Boolean] => // unsat or timeout - // the current symbolic state is not feasible - if (r.isDefined) { - logger.debug("Transition #%d is not feasible.".format(transitionNo)) - } else { - logger.debug( - s"Timed out when checking feasibility of transition #$transitionNo. Assuming it is infeasible.") - } - rewriter.pop() // LEVEL + 0 - solverContext.log( - "; } ------- STEP: %d, STACK LEVEL: %d TRANSITION: %d" - .format(stepNo, rewriter.contextLevel, transitionNo)) - (nextState, false) - // LEVEL + 0 - } - } - } - - private def assumeInvariant(stepNo: Int, state: SymbState): SymbState = { - val matchesInvFilter = invFilter == "" || stepNo.toString.matches("^" + invFilter + "$") - if (!matchesInvFilter || checkerInput.invariantsAndNegations.isEmpty) { - state - } else { - // as we have checked the invariant, we assume that it holds - val savedEx = state.ex - val savedTypes = rewriter.typeFinder.varTypes - val savedBinding = state.binding - // rename x' to x, so we are reasoning about the non-primed variables - shiftTypes(constants) - var nextState = state.setBinding(shiftBinding(state.binding, constants)) - - for (((inv, _), index) <- checkerInput.invariantsAndNegations.zipWithIndex) { - typeFinder.inferAndSave(inv) - logger.debug(s"Assuming that the invariant $index holds true") - nextState = rewriter.rewriteUntilDone(nextState.setRex(inv)) - // assume that the invariant holds true - solverContext.assertGroundExpr(nextState.ex) - } - - // restore the expression, the types, and the bindings - rewriter.typeFinder.reset(savedTypes) // forget about the types that were used to check the invariant - nextState.setRex(savedEx).setBinding(savedBinding) - } - } - - private def checkAllInvariants(stepNo: Int, transitionNo: Int, nextState: SymbState): Unit = { - val matchesInvFilter = invFilter == "" || stepNo.toString.matches("^" + invFilter + "$") - if (!matchesInvFilter) { - return // skip the check if this transition should not be checked - } - - // if the previous step was filtered, we cannot use the unchanged optimization - // Bugfix to #108: never filter out the initial step - val prevMatchesInvFilter = - stepNo > 0 && (invFilter == "" || (stepNo - 1).toString.matches("^" + invFilter + "$")) - - val invNegs = checkerInput.invariantsAndNegations.map(_._2) - for ((notInv, invNo) <- invNegs.zipWithIndex) { - logger.debug(s"Checking the invariant $invNo") - val changedPrimed = - if (prevMatchesInvFilter) { - nextState.changed // only check the invariant if it touches the changed variables - } else { - nextState.binding.toMap.keySet // check the invariant in any case, as it could be violated at the previous step - } - val savedTypes = rewriter.typeFinder.varTypes - // rename x' to x, so we are reasoning about the non-primed variables - shiftTypes(constants) - val shiftedState = nextState.setBinding(shiftBinding(nextState.binding, constants)) - rewriter.exprCache.disposeActionLevel() // renaming x' to x makes the cache inconsistent, so clean it - // check the types and the invariant - checkTypes(notInv) - checkOneInvariant(stepNo, transitionNo, shiftedState, changedPrimed, notInv) - rewriter.typeFinder.reset(savedTypes) // forget about the types that were used to check the invariant - } - } - - private def checkOneInvariant(stepNo: Int, transitionNo: Int, nextState: SymbState, changedPrimed: Set[String], - notInv: TlaEx): Unit = { - val used = - TlaExUtil.findUsedNames(notInv).map(_ + "'") // add primes as the invariant is referring to non-primed variables - if (stepNo != 0 && used.nonEmpty && used.intersect(changedPrimed).isEmpty) { - // another bugfix: look at unchecked variables, except for the case when Init has been applied! - // bugfix for #108: check the invariant over CONSTANTS, if it has not been changed before - // XXX: it might happen that an invariant over CONSTANTS is checked multiple times. We will fix that in v0.8.0. - logger.debug(s"The invariant is referring only to the UNCHANGED variables. Skipped.") - } else { - rewriter.push() - val notInvState = rewriter.rewriteUntilDone( - nextState - .setRex(notInv)) - solverContext.assertGroundExpr(notInvState.ex) - solverContext.satOrTimeout(invariantTimeout) match { - case Some(true) => - // introduce a dummy oracle to hold the transition index, we need it for the counterexample - val oracle = new MockOracle(transitionNo) - stack = (notInvState, oracle) +: stack - val filenames = dumpCounterexample(notInv) - logger.error( - s"Invariant is violated at depth $stepNo. Check the counterexample in any of ${filenames.mkString(", ")}") - if (debug) { - logger.warn(s"Dumping the arena into smt.log. This may take some time...") - // dump everything in the log - val writer = new StringWriter() - new SymbStateDecoder(solverContext, rewriter).dumpArena(notInvState, new PrintWriter(writer)) - solverContext.log(writer.getBuffer.toString) - } - // cancel the search - throw new CancelSearchException(Outcome.Error) - - case Some(false) => - logger.debug("The invariant holds true.") - - case None => - logger.debug("Timeout. Assuming that the invariant holds true.") - } - rewriter.pop() - } - } - - // returns a list of files with counterexample written - private def dumpCounterexample(notInv: TlaEx): List[String] = { - val states = new ListBuffer[NextState]() - for (((state, oracle), i) <- stack.reverse.zipWithIndex) { - val decoder = new SymbStateDecoder(solverContext, rewriter) - val transition = oracle.evalPosition(solverContext, state) - val binding = decoder.decodeStateVariables(state) - states += ((transition.toString, binding)) - } - CounterexampleWriter.writeAllFormats(checkerInput.rootModule, notInv, states.toList) - } - - private def checkTypes(expr: TlaEx): Unit = { - typeFinder.inferAndSave(expr) - if (typeFinder.typeErrors.nonEmpty) { - def print_error(e: TypeInferenceError): Unit = { - val sourceLocator: SourceLocator = SourceLocator(sourceStore.makeSourceMap, changeListener) - - val locInfo = - sourceLocator.sourceOf(e.origin) match { - case Some(loc) => loc.toString - case None => "" - } - val exStr = e.origin match { - case OperEx(op, _*) => op.name + "(...)" - case ex @ _ => ex.toString() - } - logger.error("%s, %s, type error: %s".format(locInfo, exStr, e.explanation)) - } - - typeFinder.typeErrors foreach print_error - throw new CancelSearchException(Outcome.Error) - } - } - - /** - * Remove the non-primed variables (except provided constants) - * and rename the primed variables to their non-primed versions. - * After that, remove the type finder to contain the new types only. - */ - private def shiftTypes(constants: Set[String]): Unit = { - val types = typeFinder.varTypes - val nextTypes = - types - .filter(p => p._1.endsWith("'") || constants.contains(p._1)) - .map(p => (p._1.stripSuffix("'"), p._2)) - typeFinder.reset(nextTypes) - } - - private def forgetPrimedTypes(): Unit = { - val types = typeFinder.varTypes - val unprimedTypes = types.filter(!_._1.endsWith("'")) - typeFinder.reset(unprimedTypes) - } - - // remove non-primed variables and rename primed variables to non-primed - private def shiftBinding(binding: Binding, constants: Set[String]): Binding = { - Binding(forgetNonPrimed(binding, constants).toMap - .map(p => (p._1.stripSuffix("'"), p._2))) - } - - // remove primed variables - private def forgetPrimed(binding: Binding): Binding = { - Binding(binding.toMap.filter(p => !p._1.endsWith("'"))) - } - - // remove non-primed variables, except the provided constants - private def forgetNonPrimed(binding: Binding, constants: Set[String]): Binding = { - Binding(binding.toMap.filter(p => p._1.endsWith("'") || constants.contains(p._1))) - } - - // does the transition number satisfy the given filter at the given step? - private def stepMatchesFilter(stepNo: Int, transitionNo: Int): Boolean = { - if (stepFilters.size <= stepNo) { - true // no filter applied - } else { - transitionNo.toString.matches("^%s$".format(stepFilters(stepNo))) - } - } - - private def dumpEnabledMap(): Unit = { - val filename = "enabled-map.txt" - val writer = new PrintWriter(new FileWriter(filename, false)) - val transitionsCnt = checkerInput.nextTransitions.size - writer.println("The map of enabled transitions:") - val hrule = "----%s".format((0 until transitionsCnt map (_ => "-")) mkString "") - writer.println(hrule) - writer.println(" %s".format(0 until transitionsCnt map (i => ((i / 100) % 10).toString) mkString "")) - writer.println(" %s".format(0 until transitionsCnt map (i => ((i / 10) % 10).toString) mkString "")) - writer.println(" %s".format(0 until transitionsCnt map (i => (i % 10).toString) mkString "")) - writer.println(hrule) - for ((enabled, stepNo) <- enabledList.zipWithIndex) { - val set = Set(enabled: _*) - val line = 0 until transitionsCnt map (i => if (set.contains(i)) "*" else " ") mkString "" - writer.println(s"%3d:%s".format(stepNo, line)) - } - writer.println(hrule) - writer.close() - } - - private def printRewriterSourceLoc(): Unit = { - // def getSourceLocation(ex: TlaEx) = sourceStore.find(ex.ID) - def getSourceLocation(ex: TlaEx) = { - val sourceLocator: SourceLocator = SourceLocator( - sourceStore.makeSourceMap, - changeListener - ) - sourceLocator.sourceOf(ex) - } - - rewriter.getRewritingStack().find(getSourceLocation(_).isDefined) match { - case None => - logger.error("Unable find the source of the problematic expression") - - case Some(ex) => - val loc = getSourceLocation(ex).get - logger.error(s"The problem occurs around the source location $loc") - } - } -} diff --git a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/SymbStateDecoder.scala b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/SymbStateDecoder.scala index 5089060d51..440c4f609b 100644 --- a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/SymbStateDecoder.scala +++ b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/SymbStateDecoder.scala @@ -1,15 +1,15 @@ package at.forsyte.apalache.tla.bmcmt import java.io.PrintWriter - -import at.forsyte.apalache.tla.bmcmt.implicitConversions._ import at.forsyte.apalache.tla.bmcmt.smt.SolverContext import at.forsyte.apalache.tla.bmcmt.types._ import at.forsyte.apalache.tla.lir.convenience.tla import at.forsyte.apalache.tla.lir.oper.{TlaFunOper, TlaSetOper} import at.forsyte.apalache.tla.lir.values._ import at.forsyte.apalache.tla.lir._ -import at.forsyte.apalache.tla.lir.UntypedPredefs._ +import at.forsyte.apalache.tla.lir.TypedPredefs._ +import at.forsyte.apalache.tla.lir.UntypedPredefs.BuilderExAsUntyped +import at.forsyte.apalache.tla.lir.convenience.tla.fromTlaEx import com.typesafe.scalalogging.LazyLogging import scala.collection.immutable.{HashSet, SortedSet} @@ -30,8 +30,9 @@ class SymbStateDecoder(solverContext: SolverContext, rewriter: SymbStateRewriter // compute the equivalence classes for the cells, totally suboptimally // TODO: rewrite, I did not think too much at all def iseq(c: ArenaCell, d: ArenaCell): Boolean = { + val query = tla.eql(c.toNameEx, d.toNameEx).untyped() c.cellType == d.cellType && solverContext - .evalGroundExpr(tla.eql(c.toNameEx, d.toNameEx)) == tla.bool(true).untyped() + .evalGroundExpr(query) == tla.bool(true).untyped() } def merge(cls: List[HashSet[ArenaCell]], c: ArenaCell, d: ArenaCell): List[HashSet[ArenaCell]] = { @@ -68,13 +69,20 @@ class SymbStateDecoder(solverContext: SolverContext, rewriter: SymbStateRewriter } def decodeCellToTlaEx(arena: Arena, cell: ArenaCell): TlaEx = cell.cellType match { - case BoolT() | IntT() | FailPredT() => - solverContext.evalGroundExpr(cell.toNameEx) + case BoolT() => + solverContext.evalGroundExpr(cell.toNameEx).withTag(Typed(BoolT1())) + + case IntT() => + solverContext.evalGroundExpr(cell.toNameEx).withTag(Typed(IntT1())) + + case FailPredT() => + // FailPred will be removed soon, see: https://github.com/informalsystems/apalache/issues/665 + solverContext.evalGroundExpr(cell.toNameEx).withTag(Typed(BoolT1())) case ConstT() => val found = rewriter.strValueCache.findKey(cell) if (found.isDefined) { - ValEx(TlaStr(found.get)) + tla.str(found.get).typed() } else { findCellInSet(arena, rewriter.strValueCache.values().toSeq, cell.toNameEx) match { // found among the cached keys @@ -83,47 +91,89 @@ class SymbStateDecoder(solverContext: SolverContext, rewriter: SymbStateRewriter case None => // not found, just use the name - ValEx(TlaStr(cell.toString)) // a value that was assigned by the solver, and not created by us + // a value that was assigned by the solver, and not created by us + tla.str(cell.toString).typed() } } case UnknownT() => - tla.appFun(NameEx("Unknown"), cell.toNameEx) + throw new IllegalStateException(s"Found cell $cell of cell type Unknown") + + case setT @ FinSetT(elemT) => + val setT1 = setT.toTlaType1 - case FinSetT(_) => - def inSet(e: ArenaCell) = - solverContext.evalGroundExpr(tla.in(e.toNameEx, cell.toNameEx)) == tla.bool(true).untyped() + def inSet(e: ArenaCell) = { + val mem = tla + .in(fromTlaEx(e.toNameEx).typed(elemT.toTlaType1), fromTlaEx(cell.toNameEx).typed(setT.toTlaType1)) + .typed(BoolT1()) + solverContext.evalGroundExpr(mem) == tla.bool(true).typed() + } val elems = arena.getHas(cell).filter(inSet) val decodedElems = elems.map(decodeCellToTlaEx(arena, _)) // try to normalize the set for better user experience val niceElems = decodedElems.distinct.sortWith(SymbStateDecoder.compareTlaExByStr) - tla.enumSet(niceElems: _*) - - case FinFunSetT(_, _) => - tla.funSet(decodeCellToTlaEx(arena, arena.getDom(cell)), decodeCellToTlaEx(arena, arena.getCdm(cell))) + tla + .enumSet(niceElems: _*) + .typed(setT1) + + case ffsT @ FinFunSetT(_, _) => + val fromSet = decodeCellToTlaEx(arena, arena.getDom(cell)) + val toSet = decodeCellToTlaEx(arena, arena.getCdm(cell)) + tla + .funSet(fromSet, toSet) + .typed(ffsT.toTlaType1) + + case funT @ FunT(_, _) => + val funT1 = funT.toTlaType1.asInstanceOf[FunT1] + + def appendPair(fun: TlaEx, key: ArenaCell, value: ArenaCell): TlaEx = { + val pair = tla + .colonGreater(decodeCellToTlaEx(arena, key), decodeCellToTlaEx(arena, value)) + .typed(FunT1(funT1.arg, funT1.res)) + tla + .atat(fun, pair) + .typed(funT1) + } - case FunT(_, _) => // in the new implementation, every function is represented with the relation {(x, f[x]) : x \in S} val relation = arena.getCdm(cell) - val args = - decodeCellToTlaEx(arena, relation) match { - case OperEx(TlaSetOper.enumSet, elems @ _*) => - def untuple(e: TlaEx) = e match { - case OperEx(TlaFunOper.tuple, pair @ _*) => - pair - - case _ => throw new RewriterException("Corrupted function: " + relation, NullEx) - } - elems flatMap untuple - - case _ => throw new RewriterException("Corrupted function: " + relation, NullEx) - } + def isInRelation(pair: ArenaCell): Boolean = { + val mem = tla + .in(fromTlaEx(pair.toNameEx).typed(funT1.arg), + fromTlaEx(relation.toNameEx).typed(TupT1(funT1.arg, funT1.res))) + .typed(BoolT1()) + solverContext.evalGroundExpr(mem) == tla.bool(true).typed(BoolT1()) + } - tla.atat(args: _*) + val pairs = arena.getHas(relation) + pairs find isInRelation match { + case None => + // this is a pathological case, produce: [ x \in {} |-> x ] + tla + .funDef(tla.name("x").typed(funT1.arg), tla.name("x").typed(funT1.arg), + tla.enumSet().typed(SetT1(funT1.res))) + .typed(funT1) + + case Some(first) => + // this is the common case + val head = arena.getHas(first) + val firstPair = tla + .colonGreater(decodeCellToTlaEx(arena, head(0)), decodeCellToTlaEx(arena, head(1))) + .typed(FunT1(funT1.arg, funT1.res)) + pairs.tail.foldLeft(firstPair) { case (f, p) => + if (p == first) { + f + } else { + val pair = arena.getHas(p) + appendPair(f, pair(0), pair(1)) + } + } + } - case SeqT(_) => + case SeqT(elemT) => + val elemT1 = elemT.toTlaType1 val startEndFun = arena.getHas(cell) map (decodeCellToTlaEx(arena, _)) startEndFun match { case ValEx(TlaInt(start)) :: ValEx(TlaInt(end)) +: cells => @@ -131,7 +181,8 @@ class SymbStateDecoder(solverContext: SolverContext, rewriter: SymbStateRewriter def isIn(elem: TlaEx, no: Int): Boolean = no >= start && no < end val filtered = cells.zipWithIndex filter (isIn _).tupled map (_._1) - tla.tuple(filtered: _*) // return a tuple as it is the canonical representation of a sequence + // return a tuple as it is the canonical representation of a sequence + tla.tuple(filtered: _*).typed(SeqT1(elemT1)) case _ => throw new RewriterException("Corrupted sequence: " + startEndFun, NullEx) } @@ -155,33 +206,34 @@ class SymbStateDecoder(solverContext: SolverContext, rewriter: SymbStateRewriter } else { val index = keyList.indexOf(key) val valueCell = fieldValues(index) - ValEx(TlaStr(key)) +: decodeCellToTlaEx(arena, valueCell) +: es + tla.str(key).typed() +: decodeCellToTlaEx(arena, valueCell) +: es } } val keysAndValues = keyList.reverse.foldLeft(List[TlaEx]())(eachField) if (keysAndValues.nonEmpty) { - OperEx(TlaFunOper.enum, keysAndValues: _*) + OperEx(TlaFunOper.enum, keysAndValues: _*)(Typed(r.toTlaType1)) } else { logger.error( s"Decoder: Found an empty record $cell when decoding a counterexample, domain = $domCell. This is a bug.") // for debugging purposes, just return a string - ValEx(TlaStr(s"Empty record domain $domCell")) + tla.str(s"Empty record domain $domCell").typed() } case t @ TupleT(_) => val tupleElems = arena.getHas(cell) val elemAsExprs = tupleElems.map(c => decodeCellToTlaEx(arena, c)) - tla.tuple(elemAsExprs: _*) + tla.tuple(elemAsExprs: _*).typed(t.toTlaType1) case PowSetT(t @ FinSetT(_)) => - tla.powSet(decodeCellToTlaEx(arena, arena.getDom(cell))) + val baseset = decodeCellToTlaEx(arena, arena.getDom(cell)) + tla.powSet(baseset).typed(SetT1(TlaType1.fromTypeTag(baseset.typeTag))) - case InfSetT(elemT) if cell == arena.cellIntSet() => - ValEx(TlaIntSet) + case InfSetT(_) if cell == arena.cellIntSet() => + tla.intSet().typed(SetT1(IntT1())) - case InfSetT(elemT) if cell == arena.cellNatSet() => - ValEx(TlaNatSet) + case InfSetT(_) if cell == arena.cellNatSet() => + tla.natSet().typed(SetT1(IntT1())) case _ => throw new RewriterException("Don't know how to decode the cell %s of type %s".format(cell, cell.cellType), NullEx) @@ -199,7 +251,8 @@ class SymbStateDecoder(solverContext: SolverContext, rewriter: SymbStateRewriter private def findCellInSet(arena: Arena, cells: Seq[ArenaCell], ex: TlaEx): Option[ArenaCell] = { def isEq(c: ArenaCell): Boolean = { - ValEx(TlaBool(true)) == solverContext.evalGroundExpr(tla.and(tla.eql(c.toNameEx, ex))) + val query = tla.and(tla.eql(c.toNameEx, ex)) + tla.bool(true).typed() == solverContext.evalGroundExpr(query.untyped()) } cells.find(isEq) @@ -208,10 +261,14 @@ class SymbStateDecoder(solverContext: SolverContext, rewriter: SymbStateRewriter def reverseMapVar(expr: TlaEx, varName: String, cell: ArenaCell): TlaEx = { def rec(ex: TlaEx): TlaEx = ex match { case NameEx(name) => - if (name != cell.toNameEx.name) ex else NameEx(varName) + if (name != cell.toNameEx.name) { + ex + } else { + tla.name(varName).untyped() + } case OperEx(oper, args @ _*) => - OperEx(oper, args map rec: _*) + OperEx(oper, args map rec: _*)(Untyped()) case _ => ex diff --git a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/SymbStateRewriter.scala b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/SymbStateRewriter.scala index 2f0e8c8644..96dcae1456 100644 --- a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/SymbStateRewriter.scala +++ b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/SymbStateRewriter.scala @@ -5,7 +5,6 @@ import at.forsyte.apalache.tla.bmcmt.analyses.{ExprGradeStore, FormulaHintsStore import at.forsyte.apalache.tla.bmcmt.caches.{ExprCache, IntValueCache, RecordDomainCache, StrValueCache} import at.forsyte.apalache.tla.bmcmt.rewriter.{Recoverable, RewriterConfig, SymbStateRewriterSnapshot} import at.forsyte.apalache.tla.bmcmt.smt.SolverContext -import at.forsyte.apalache.tla.bmcmt.types.{CellT, TypeFinder} import at.forsyte.apalache.tla.lir.TlaEx /** @@ -50,13 +49,6 @@ trait SymbStateRewriter extends StackableContext with MessageStorage with Recove */ def config: RewriterConfig - /** - * A type finder. - * - * @return a type finder that can produce cell types - */ - def typeFinder: TypeFinder[CellT] - /** * The cache for lazy equalities, to avoid generating the same equality constraints many times. */ diff --git a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/SymbStateRewriterAuto.scala b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/SymbStateRewriterAuto.scala index ff69296e40..56796f3510 100644 --- a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/SymbStateRewriterAuto.scala +++ b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/SymbStateRewriterAuto.scala @@ -4,8 +4,6 @@ import at.forsyte.apalache.tla.bmcmt.analyses._ import at.forsyte.apalache.tla.bmcmt.caches.{ExprCache, IntValueCache, RecordDomainCache, StrValueCache} import at.forsyte.apalache.tla.bmcmt.rewriter.{RewriterConfig, SymbStateRewriterSnapshot} import at.forsyte.apalache.tla.bmcmt.smt.SolverContext -import at.forsyte.apalache.tla.bmcmt.types.CellT -import at.forsyte.apalache.tla.bmcmt.types.eager.TrivialTypeFinder import at.forsyte.apalache.tla.lir.TlaEx /** @@ -31,8 +29,6 @@ class SymbStateRewriterAuto(private var _solverContext: SolverContext) extends S var config: RewriterConfig = new RewriterConfig - val typeFinder = new TrivialTypeFinder() - /** * A solver context that is populated by the rewriter. */ @@ -50,7 +46,7 @@ class SymbStateRewriterAuto(private var _solverContext: SolverContext) extends S private val exprGradeStoreImpl = new ExprGradeStoreImpl() private val exprGradeAnalysis = new ExprGradeAnalysis(exprGradeStoreImpl) - private val impl = new SymbStateRewriterImpl(solverContext, typeFinder, exprGradeStore) + private val impl = new SymbStateRewriterImpl(solverContext, exprGradeStore) override def contextLevel: Int = impl.contextLevel @@ -68,21 +64,10 @@ class SymbStateRewriterAuto(private var _solverContext: SolverContext) extends S override def exprGradeStore: ExprGradeStore = exprGradeStoreImpl - private def reset(arena: Arena, binding: Binding): Unit = { - def add(m: Map[String, CellT], c: ArenaCell) = m + (c.toString -> c.cellType) - val cellTypes = arena.cellMap.values.foldLeft(Map[String, CellT]())(add) - def addName(m: Map[String, CellT], p: (String, ArenaCell)) = m + (p._1 -> p._2.cellType) - val cellAndBindingTypes = binding.toMap.foldLeft(cellTypes)(addName) - // propagate cell types and bindings to the type inference engine - typeFinder.reset(cellAndBindingTypes) - } + private def reset(arena: Arena, binding: Binding): Unit = {} private def preprocess(ex: TlaEx): Unit = { exprGradeAnalysis.labelExpr(consts, vars, ex) - typeFinder.inferAndSave(ex) - if (typeFinder.typeErrors.nonEmpty) { - throw new TypeInferenceException(typeFinder.typeErrors) - } } override def rewriteOnce(state: SymbState): SymbStateRewriter.RewritingResult = { diff --git a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/SymbStateRewriterImpl.scala b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/SymbStateRewriterImpl.scala index 530165bd1c..bfaff07068 100644 --- a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/SymbStateRewriterImpl.scala +++ b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/SymbStateRewriterImpl.scala @@ -9,8 +9,6 @@ import at.forsyte.apalache.tla.bmcmt.rewriter.{ } import at.forsyte.apalache.tla.bmcmt.rules._ import at.forsyte.apalache.tla.bmcmt.smt.SolverContext -import at.forsyte.apalache.tla.bmcmt.types.eager.TrivialTypeFinder -import at.forsyte.apalache.tla.bmcmt.types.{CellT, TypeFinder} import at.forsyte.apalache.tla.lir._ import at.forsyte.apalache.tla.lir.convenience.tla import at.forsyte.apalache.tla.lir.oper._ @@ -30,13 +28,12 @@ import scala.collection.mutable *

TODO: rename this class to RewriterImpl?

* * @param _solverContext a fresh solver context that will be populated with constraints - * @param typeFinder a type finder (assuming that typeFinder.inferAndSave has been called already) * @param exprGradeStore a labeling scheme that associated a grade with each expression; * it is required to distinguish between state-level and action-level expressions. * @param profilerListener optional listener that is used to profile the rewriting rules * @author Igor Konnov */ -class SymbStateRewriterImpl(private var _solverContext: SolverContext, var typeFinder: TypeFinder[CellT], +class SymbStateRewriterImpl(private var _solverContext: SolverContext, val exprGradeStore: ExprGradeStore = new ExprGradeStoreImpl(), val profilerListener: Option[MetricProfilerListener] = None) extends SymbStateRewriter with Serializable with Recoverable[SymbStateRewriterSnapshot] { @@ -384,11 +381,14 @@ class SymbStateRewriterImpl(private var _solverContext: SolverContext, var typeF // use cache or compute a new expression exprCache.get(state.ex) match { - case Some(eg: (TlaEx, ExprGrade.Value)) => + case Some(eg: (TlaEx, ExprGrade.Value)) if eg._1.typeTag == state.ex.typeTag => + // In rare cases, the expression may be equal to a cache expression, but they may have different types. + // For instance, {}: Set(Int) and {}: Set(Set(Int)) are syntactically the same but having different types. + // Hence, we compare types as well. As this case is rare, we don't store the types directly in the cache. solverContext.log(s"; Using cached value ${eg._1} for expression ${state.ex}") state.setRex(eg._1) - case None => + case _ => // Get the SMT metrics before translating the expression. // Note that we are not doing that in the recursive function, // as the new expressions there will not have source information. @@ -452,8 +452,8 @@ class SymbStateRewriterImpl(private var _solverContext: SolverContext, var typeF * @return the snapshot */ override def snapshot(): SymbStateRewriterSnapshot = { - new SymbStateRewriterSnapshot(typeFinder.asInstanceOf[TrivialTypeFinder].snapshot(), intValueCache.snapshot(), - intRangeCache.snapshot(), strValueCache.snapshot(), recordDomainCache.snapshot(), exprCache.snapshot()) + new SymbStateRewriterSnapshot(intValueCache.snapshot(), intRangeCache.snapshot(), strValueCache.snapshot(), + recordDomainCache.snapshot(), exprCache.snapshot()) } /** @@ -463,7 +463,6 @@ class SymbStateRewriterImpl(private var _solverContext: SolverContext, var typeF * @param shot a snapshot */ override def recover(shot: SymbStateRewriterSnapshot): Unit = { - typeFinder.asInstanceOf[TrivialTypeFinder].recover(shot.typeFinderSnapshot) intValueCache.recover(shot.intValueCacheSnapshot) intRangeCache.recover(shot.intRangeCacheSnapshot) strValueCache.recover(shot.strValueCacheSnapshot) diff --git a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/VCGenerator.scala b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/VCGenerator.scala index 12f239cbe5..f8821b7358 100644 --- a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/VCGenerator.scala +++ b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/VCGenerator.scala @@ -1,14 +1,16 @@ package at.forsyte.apalache.tla.bmcmt -import at.forsyte.apalache.tla.lir._ +import at.forsyte.apalache.tla.lir.{BoolT1, _} import at.forsyte.apalache.tla.lir.convenience.tla import at.forsyte.apalache.tla.lir.oper.TlaBoolOper import at.forsyte.apalache.tla.lir.transformations.TransformationTracker import at.forsyte.apalache.tla.lir.transformations.standard.DeepCopy import at.forsyte.apalache.tla.pp.NormalizedNames -import at.forsyte.apalache.tla.lir.UntypedPredefs._ +import TypedPredefs._ import com.typesafe.scalalogging.LazyLogging +import scala.annotation.tailrec + /** * Generator of verification conditions. In the current implementation, VCGenerator takes an invariant candidate, * decomposes the invariant into smaller invariant candidates and produces negations of the invariant candidates. @@ -44,13 +46,15 @@ class VCGenerator(tracker: TransformationTracker) extends LazyLogging { } private def introConditions(inputInv: TlaEx): Seq[TlaOperDecl] = { - def mapToDecls(smallInv: TlaEx, index: Int): Seq[TlaOperDecl] = { + def mapToDecls(invPiece: TlaEx, index: Int): Seq[TlaOperDecl] = { val deepCopy = DeepCopy(tracker) - val smallInvCopy = deepCopy.deepCopyEx(smallInv) + val invPieceCopy = deepCopy.deepCopyEx(invPiece) + val tag = inputInv.typeTag val positive = - TlaOperDecl(NormalizedNames.VC_INV_PREFIX + index, List(), smallInvCopy)(Untyped()) + TlaOperDecl(NormalizedNames.VC_INV_PREFIX + index, List(), invPieceCopy)(tag) + val notInvPieceCopy = tla.not(invPieceCopy).typed(BoolT1()) val negative = - TlaOperDecl(NormalizedNames.VC_NOT_INV_PREFIX + index, List(), tla.not(smallInvCopy))(Untyped()) + TlaOperDecl(NormalizedNames.VC_NOT_INV_PREFIX + index, List(), notInvPieceCopy)(tag) Seq(positive, negative) } @@ -59,11 +63,11 @@ class VCGenerator(tracker: TransformationTracker) extends LazyLogging { fragments.zipWithIndex.flatMap { case (e, i) => mapToDecls(e, i) } } - private def splitInv(universalsRev: Seq[(String, TlaEx)], inv: TlaEx): Seq[TlaEx] = { + private def splitInv(universalsRev: Seq[(NameEx, TlaEx)], inv: TlaEx): Seq[TlaEx] = { inv match { - case OperEx(TlaBoolOper.forall, NameEx(varName), set, pred) => + case OperEx(TlaBoolOper.forall, nameEx @ NameEx(_), set, pred) => // \A x \in S: B /\ C is equivalent to (\A x \in S: B) /\ (\A x \in S: C) - splitInv((varName, set) +: universalsRev, pred) + splitInv((nameEx, set) +: universalsRev, pred) case OperEx(TlaBoolOper.and, args @ _*) => // we split A /\ B into the set {A, B} @@ -75,13 +79,14 @@ class VCGenerator(tracker: TransformationTracker) extends LazyLogging { } } - private def decorateWithUniversals(universalsRev: Seq[(String, TlaEx)], expr: TlaEx): TlaEx = { + @tailrec + private def decorateWithUniversals(universalsRev: Seq[(NameEx, TlaEx)], expr: TlaEx): TlaEx = { universalsRev match { case Nil => expr - case (name, set) :: tail => - decorateWithUniversals(tail, tla.forall(NameEx(name), set, expr)) + case (nameEx, set) :: tail => + decorateWithUniversals(tail, tla.forall(nameEx, set, expr).typed(BoolT1())) } } } diff --git a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/analyses/ExpansionMarker.scala b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/analyses/ExpansionMarker.scala index 60ae55d8fe..47515794ea 100644 --- a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/analyses/ExpansionMarker.scala +++ b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/analyses/ExpansionMarker.scala @@ -4,7 +4,6 @@ import at.forsyte.apalache.tla.lir._ import at.forsyte.apalache.tla.lir.oper._ import at.forsyte.apalache.tla.lir.transformations.standard.KeraLanguagePred import at.forsyte.apalache.tla.lir.transformations.{LanguageWatchdog, TlaExTransformation, TransformationTracker} -import at.forsyte.apalache.tla.lir.UntypedPredefs._ import com.google.inject.Inject import com.typesafe.scalalogging.LazyLogging @@ -32,69 +31,82 @@ class ExpansionMarker @Inject() (tracker: TransformationTracker) extends TlaExTr // that requires an expanded set, e.g., S \\union (SUBSET T), the parameter shallExpand changes to true. def transform(shallExpand: Boolean): TlaExTransformation = tracker.trackEx { case ex @ OperEx(op @ TlaSetOper.powerset, underlyingSet) => + val tag = ex.typeTag if (shallExpand) { // Expand the set as well as the underlying set! logger.warn(s"The set $ex will be expanded. This will blow up the solver.") - OperEx(BmcOper.expand, OperEx(op, transform(true)(underlyingSet))) + OperEx(BmcOper.expand, OperEx(op, transform(true)(underlyingSet))(tag))(tag) } else { // Do not expand the set itself, but expand the underlying set! - OperEx(op, transform(true)(underlyingSet)) + OperEx(op, transform(true)(underlyingSet))(tag) } case ex @ OperEx(op @ TlaSetOper.funSet, dom, cdm) => + val tag = ex.typeTag if (shallExpand) { // Expand everything, including the function set. logger.warn(s"The set $ex will be expanded. This will blow up the solver.") - OperEx(BmcOper.expand, OperEx(op, transform(true)(dom), transform(true)(cdm))) + OperEx(BmcOper.expand, OperEx(op, transform(true)(dom), transform(true)(cdm))(tag))(tag) } else { // Only expand the domain, but keep the co-domain unexpanded, // e.g., in [SUBSET S -> SUBSET T], while SUBSET S is expanded, the co-domain SUBSET T can be left as is - OperEx(op, transform(true)(dom), transform(false)(cdm)) + OperEx(op, transform(true)(dom), transform(false)(cdm))(tag) } // simple propagation analysis that tells us what to expand - case OperEx(op @ BmcOper.`skolem`, OperEx(TlaBoolOper.exists, name, set, pred)) => + case ex @ OperEx(op @ BmcOper.`skolem`, OperEx(TlaBoolOper.exists, name, set, pred)) => // a skolemizable existential is allowed to keep its set unexpanded - OperEx(op, OperEx(TlaBoolOper.exists, name, transform(false)(set), transform(false)(pred))) + val tag = ex.typeTag + OperEx(op, OperEx(TlaBoolOper.exists, name, transform(false)(set), transform(false)(pred))(tag))(tag) - case OperEx(op @ TlaOper.chooseBounded, name, set, pred) => + case ex @ OperEx(op @ TlaOper.chooseBounded, name, set, pred) => // CHOOSE is essentially a skolemizable existential with the constraint of uniqueness - OperEx(op, name, transform(false)(set), transform(false)(pred)) + val tag = ex.typeTag + OperEx(op, name, transform(false)(set), transform(false)(pred))(tag) - case OperEx(op, name, set, pred) if op == TlaBoolOper.exists || op == TlaBoolOper.forall => + case ex @ OperEx(op, name, set, pred) if op == TlaBoolOper.exists || op == TlaBoolOper.forall => // non-skolemizable quantifiers require their sets to be expanded - OperEx(op, name, transform(true)(set), transform(false)(pred)) + val tag = ex.typeTag + OperEx(op, name, transform(true)(set), transform(false)(pred))(tag) - case OperEx(op @ TlaSetOper.in, elem, set) => + case ex @ OperEx(op @ TlaSetOper.in, elem, set) => // when checking e \in S, we can keep S unexpanded, but e should be expanded - OperEx(op, transform(true)(elem), transform(false)(set)) + val tag = ex.typeTag + OperEx(op, transform(true)(elem), transform(false)(set))(tag) - case OperEx(op, args @ _*) if op == TlaSetOper.cup || op == TlaSetOper.union => + case ex @ OperEx(op, args @ _*) if op == TlaSetOper.cup || op == TlaSetOper.union => // binary union and UNION require arena cells, hence, expand - OperEx(op, args map transform(true): _*) + val tag = ex.typeTag + OperEx(op, args map transform(true): _*)(tag) - case OperEx(op @ TlaSetOper.filter, name, set, pred) => + case ex @ OperEx(op @ TlaSetOper.filter, name, set, pred) => // For the moment, we require the set to be expanded. However, we could think of collecting constraints on the way. - OperEx(op, name, transform(true)(set), transform(false)(pred)) + val tag = ex.typeTag + OperEx(op, name, transform(true)(set), transform(false)(pred))(tag) - case OperEx(op, body, args @ _*) if op == TlaSetOper.map || op == TlaFunOper.funDef || op == TlaFunOper.recFunDef => + case ex @ OperEx(op, body, args @ _*) + if op == TlaSetOper.map || op == TlaFunOper.funDef || op == TlaFunOper.recFunDef => val tbody: TlaEx = transform(false)(body) val targs = args map transform(true) - OperEx(op, tbody +: targs: _*) + val tag = ex.typeTag + OperEx(op, tbody +: targs: _*)(tag) - case LetInEx(body, defs @ _*) => + case ex @ LetInEx(body, defs @ _*) => // at this point, we only have nullary let-in definitions def mapDef(df: TlaOperDecl) = df.copy(body = transform(shallExpand)(df.body)) - LetInEx(transform(shallExpand)(body), defs map mapDef: _*) + val tag = ex.typeTag + LetInEx(transform(shallExpand)(body), defs map mapDef: _*)(tag) - case OperEx(BmcOper.withType, expr, annot) => + case ex @ OperEx(BmcOper.withType, expr, annot) => // transform the expression, but not the annotation! See https://github.com/informalsystems/apalache/issues/292 - OperEx(BmcOper.withType, transform(shallExpand)(expr), annot) + val tag = ex.typeTag + OperEx(BmcOper.withType, transform(shallExpand)(expr), annot)(tag) - case OperEx(oper, args @ _*) => + case ex @ OperEx(oper, args @ _*) => // try to descend in the children, which may contain Boolean operations, e.g., { \E x \in S: P } - OperEx(oper, args map transform(shallExpand): _*) + val tag = ex.typeTag + OperEx(oper, args map transform(shallExpand): _*)(tag) case terminal => terminal // terminal expression, stop here diff --git a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/analyses/SkolemizationMarker.scala b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/analyses/SkolemizationMarker.scala index fb70eba1ad..dd85a104ec 100644 --- a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/analyses/SkolemizationMarker.scala +++ b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/analyses/SkolemizationMarker.scala @@ -3,7 +3,6 @@ package at.forsyte.apalache.tla.bmcmt.analyses import at.forsyte.apalache.tla.lir._ import at.forsyte.apalache.tla.lir.oper._ import at.forsyte.apalache.tla.lir.transformations.{TlaExTransformation, TransformationTracker} -import at.forsyte.apalache.tla.lir.UntypedPredefs._ import com.google.inject.Inject import com.typesafe.scalalogging.LazyLogging import at.forsyte.apalache.tla.lir.convenience._ @@ -29,39 +28,40 @@ class SkolemizationMarker @Inject() (tracker: TransformationTracker) extends Tla } def transform: TlaExTransformation = tracker.trackEx { - case OperEx(TlaBoolOper.exists, name, set, pred) => - OperEx(BmcOper.skolem, tla.exists(name, set, transform(pred))) + case ex @ OperEx(TlaBoolOper.exists, name, set, pred) => + val tag = ex.typeTag + OperEx(BmcOper.skolem, OperEx(TlaBoolOper.exists, name, set, transform(pred))(tag))(tag) - case OperEx(TlaBoolOper.forall, name, set, pred) => + case ex @ OperEx(TlaBoolOper.forall, name, set, pred) => // it is fine to skolemize existentials under \A, as \A is translated into a conjunction - tla.forall(name, set, transform(pred)) + OperEx(TlaBoolOper.forall, name, set, transform(pred))(ex.typeTag) case op @ OperEx(TlaArithOper.ge, OperEx(TlaFiniteSetOper.cardinality, _), ValEx(TlaInt(intVal))) if intVal.isValidInt => // this constraint can be solved more efficiently than the direct computation of Cardinality - OperEx(BmcOper.constCard, op) + OperEx(BmcOper.constCard, op)(op.typeTag) case ex @ OperEx(TlaBoolOper.not, _) => ex // stop here. This should be a leaf (and rare) expression, as we are dealing with the NNF. - case OperEx(TlaBoolOper.and, args @ _*) => - tla.and(args map transform: _*) + case ex @ OperEx(TlaBoolOper.and, args @ _*) => + OperEx(TlaBoolOper.and, args map transform: _*)(ex.typeTag) - case OperEx(TlaBoolOper.or, args @ _*) => - tla.or(args map transform: _*) + case ex @ OperEx(TlaBoolOper.or, args @ _*) => + OperEx(TlaBoolOper.or, args map transform: _*)(ex.typeTag) - case OperEx(TlaControlOper.ifThenElse, cond, left, right) => + case ex @ OperEx(TlaControlOper.ifThenElse, cond, left, right) => // try to identify existentials in the both arms - tla.ite(cond, transform(left), transform(right)) + OperEx(TlaControlOper.ifThenElse, cond, transform(left), transform(right))(ex.typeTag) // We omit skolemization of the existentials in the predicate, // as the predicate is used in both the negated and non-negated forms. // Effectively, IF-THEN-ELSE requires both \E and \A forms - case LetInEx(body, defs @ _*) => + case ex @ LetInEx(body, defs @ _*) => // at this point, we only have nullary let-in definitions def mapDef(df: TlaOperDecl) = df.copy(body = transform(df.body)) - LetInEx(transform(body), defs map mapDef: _*) + LetInEx(transform(body), defs map mapDef: _*)(ex.typeTag) case ex @ OperEx(oper, args @ _*) => // bugfix for #148: do not descend into value expressions, as Skolemization of non-formulas is unsound diff --git a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/config/CheckerExceptionAdapter.scala b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/config/CheckerExceptionAdapter.scala index ab245719dc..116bb85ae3 100644 --- a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/config/CheckerExceptionAdapter.scala +++ b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/config/CheckerExceptionAdapter.scala @@ -3,13 +3,16 @@ package at.forsyte.apalache.tla.bmcmt.config import at.forsyte.apalache.infra.{ErrorMessage, ExceptionAdapter, FailureMessage, NormalErrorMessage} import at.forsyte.apalache.tla.assignments.{AssignmentException, CoverData} import at.forsyte.apalache.tla.bmcmt._ -import at.forsyte.apalache.tla.bmcmt.types.TypeInferenceError import at.forsyte.apalache.tla.imp.SanyException import at.forsyte.apalache.tla.imp.src.SourceStore import at.forsyte.apalache.tla.lir.storage.{ChangeListener, SourceLocator} -import at.forsyte.apalache.tla.lir.{LanguagePredError, MalformedTlaError, OperEx, UID} -import at.forsyte.apalache.tla.pp.{ConfigurationError, IrrecoverablePreprocessingError, NotInKeraError, TlaInputError} -import at.forsyte.apalache.tla.typecheck.{TypingException, TypingInputException} +import at.forsyte.apalache.tla.lir.{ + LanguagePredError, MalformedTlaError, OperEx, OutdatedAnnotationsError, TypingException, UID +} +import at.forsyte.apalache.tla.pp.{ + ConfigurationError, IrrecoverablePreprocessingError, NotInKeraError, OverridingError, TlaInputError +} +import at.forsyte.apalache.tla.typecheck.TypingInputException import com.typesafe.scalalogging.LazyLogging import javax.inject.{Inject, Singleton} @@ -32,6 +35,9 @@ class CheckerExceptionAdapter @Inject() (sourceStore: SourceStore, changeListene case err: ConfigurationError => NormalErrorMessage("Configuration error (see the manual): " + err.getMessage) + case err: OverridingError => + NormalErrorMessage("Configuration error (see the manual): " + err.getMessage) + case err: TlaInputError => NormalErrorMessage("Input error (see the manual): " + err.getMessage) @@ -40,8 +46,8 @@ class CheckerExceptionAdapter @Inject() (sourceStore: SourceStore, changeListene logger.info(" [https://apalache.informal.systems/docs/apalache/principles.html#assignments]") NormalErrorMessage("Assignment error: " + err.getMessage) - case err: TypeInferenceException => - val msg = "%s\n%s".format(err.getMessage, err.errors.map(ofTypeInferenceError).mkString("\n")) + case err: OutdatedAnnotationsError => + val msg = "%s: rewriter error: %s".format(findLoc(err.causeExpr.ID), err.getMessage) NormalErrorMessage(msg) case err: LanguagePredError => @@ -112,13 +118,4 @@ class CheckerExceptionAdapter @Inject() (sourceStore: SourceStore, changeListene case None => "" } } - - def ofTypeInferenceError(e: TypeInferenceError): String = { - val locInfo = findLoc(e.origin.ID) - val exStr = e.origin match { - case OperEx(op, _*) => op.name - case ex @ _ => ex.toString() - } - "%s, %s, type error: %s".format(locInfo, exStr, e.explanation) - } } diff --git a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/config/CheckerModule.scala b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/config/CheckerModule.scala index adc8bb9c35..ae6d4ea2b5 100644 --- a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/config/CheckerModule.scala +++ b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/config/CheckerModule.scala @@ -7,8 +7,6 @@ import at.forsyte.apalache.io.annotations.{AnnotationStoreProvider, PrettyWriter import at.forsyte.apalache.tla.assignments.passes._ import at.forsyte.apalache.tla.bmcmt.analyses._ import at.forsyte.apalache.tla.bmcmt.passes._ -import at.forsyte.apalache.tla.bmcmt.types.eager.TrivialTypeFinder -import at.forsyte.apalache.tla.bmcmt.types.{CellT, TypeFinder} import at.forsyte.apalache.tla.imp.passes.{SanyParserPass, SanyParserPassImpl} import at.forsyte.apalache.tla.lir.io.TlaWriterFactory import at.forsyte.apalache.tla.lir.storage.ChangeListener @@ -42,8 +40,6 @@ class CheckerModule extends AbstractModule { .to(classOf[FormulaHintsStoreImpl]) bind(classOf[ExprGradeStore]) .to(classOf[ExprGradeStoreImpl]) - bind(new TypeLiteral[TypeFinder[CellT]] {}) - .to(classOf[TrivialTypeFinder]) // using a trivial type finder // writers bind(classOf[TlaWriterFactory]) @@ -64,24 +60,25 @@ class CheckerModule extends AbstractModule { bind(classOf[Pass]) .annotatedWith(Names.named("InitialPass")) .to(classOf[SanyParserPass]) - // the next pass is ConfigurationPass - bind(classOf[ConfigurationPass]) - .to(classOf[ConfigurationPassImpl]) - bind(classOf[Pass]) - .annotatedWith(Names.named("AfterParser")) - .to(classOf[ConfigurationPass]) // The next pass is Snowcat that is called EtcTypeCheckerPassImpl for now. // We provide guice with a concrete implementation here, as we also use PostTypeCheckerPassImpl later in the pipeline. bind(classOf[Pass]) - .annotatedWith(Names.named("AfterConfiguration")) + .annotatedWith(Names.named("AfterParser")) .to(classOf[EtcTypeCheckerPassImpl]) + // the next pass is ConfigurationPass + bind(classOf[ConfigurationPass]) + .to(classOf[ConfigurationPassImpl]) + bind(classOf[Pass]) + .annotatedWith(Names.named("AfterTypeChecker")) + .to(classOf[ConfigurationPass]) + // the next pass is DesugarerPass bind(classOf[DesugarerPass]) .to(classOf[DesugarerPassImpl]) bind(classOf[Pass]) - .annotatedWith(Names.named("AfterTypeChecker")) + .annotatedWith(Names.named("AfterConfiguration")) .to(classOf[DesugarerPass]) // the next pass is UnrollPass bind(classOf[UnrollPass]) diff --git a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/config/TransformationTrackerProvider.scala b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/config/TransformationTrackerProvider.scala index 08b474f801..7fcab035db 100644 --- a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/config/TransformationTrackerProvider.scala +++ b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/config/TransformationTrackerProvider.scala @@ -1,9 +1,9 @@ package at.forsyte.apalache.tla.bmcmt.config -import at.forsyte.apalache.tla.bmcmt.types.eager.TrivialTypeFinder import at.forsyte.apalache.tla.lir.storage.ChangeListener import at.forsyte.apalache.tla.lir.transformations.TransformationTracker import at.forsyte.apalache.tla.lir.transformations.impl.TrackerWithListeners +import at.forsyte.apalache.tla.typecheck.integration.TypeWatchdogTransformationListener import com.google.inject.{Inject, Provider, Singleton} /** @@ -19,14 +19,13 @@ import com.google.inject.{Inject, Provider, Singleton} * to pass a list of transformation listeners to the tracker, while the listeners are injected by Guice. * * @param changeListener a listener that records which expression was transformed into which expression - * * @author Igor Konnov */ @Singleton -class TransformationTrackerProvider @Inject() (changeListener: ChangeListener, trivialTypeFinder: TrivialTypeFinder) - extends Provider[TransformationTracker] { +class TransformationTrackerProvider @Inject() (changeListener: ChangeListener) extends Provider[TransformationTracker] { - private val tracker = TrackerWithListeners(changeListener, trivialTypeFinder) + private val tracker = + TrackerWithListeners(new TypeWatchdogTransformationListener(), changeListener) override def get(): TransformationTracker = { tracker diff --git a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/exceptions.scala b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/exceptions.scala index 6506f7fef3..b443dcfc64 100644 --- a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/exceptions.scala +++ b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/exceptions.scala @@ -1,7 +1,6 @@ package at.forsyte.apalache.tla.bmcmt -import at.forsyte.apalache.tla.bmcmt.types.TypeInferenceError -import at.forsyte.apalache.tla.lir.{NullEx, TlaEx} +import at.forsyte.apalache.tla.lir.TlaEx /** * A generic error that occurred in the model checker. @@ -27,16 +26,6 @@ class RewriterException(message: String, causeExpr: TlaEx) extends CheckerExcept */ class TypeException(message: String, causeExpr: TlaEx) extends CheckerException(message, causeExpr) -/** - * An exception that is thrown when at least one type inference error was found. - * @param errors the list of type inference errors - */ -class TypeInferenceException(val errors: Seq[TypeInferenceError]) - extends CheckerException( - "Type inference failed, first error: %s" - .format(if (errors.nonEmpty) errors.head.explanation else "unknown"), - if (errors.nonEmpty) errors.head.origin else NullEx) - /** * This exception is thrown when QStateRewrite cannot find an applicable rule. * diff --git a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/passes/BoundedCheckerPassImpl.scala b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/passes/BoundedCheckerPassImpl.scala index 3a2f9ee3b3..26304a68d6 100644 --- a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/passes/BoundedCheckerPassImpl.scala +++ b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/passes/BoundedCheckerPassImpl.scala @@ -10,7 +10,6 @@ import at.forsyte.apalache.tla.bmcmt.rewriter.{MetricProfilerListener, RewriterC import at.forsyte.apalache.tla.bmcmt.search._ import at.forsyte.apalache.tla.bmcmt.smt.{RecordingSolverContext, SolverConfig} import at.forsyte.apalache.tla.bmcmt.trex._ -import at.forsyte.apalache.tla.bmcmt.types.eager.TrivialTypeFinder import at.forsyte.apalache.tla.imp.src.SourceStore import at.forsyte.apalache.tla.lir.NullEx import at.forsyte.apalache.tla.lir.storage.ChangeListener @@ -96,7 +95,6 @@ class BoundedCheckerPassImpl @Inject() (val options: PassOptions, hintsStore: Fo solverConfig: SolverConfig): Boolean = { val solverContext: RecordingSolverContext = RecordingSolverContext.createZ3(None, solverConfig) - val typeFinder = new TrivialTypeFinder val metricProfilerListener = if (solverConfig.profile) { logger.info("Profiling data will be written to profile.csv") @@ -106,7 +104,7 @@ class BoundedCheckerPassImpl @Inject() (val options: PassOptions, hintsStore: Fo } val rewriter: SymbStateRewriterImpl = - new SymbStateRewriterImpl(solverContext, typeFinder, exprGradeStore, metricProfilerListener) + new SymbStateRewriterImpl(solverContext, exprGradeStore, metricProfilerListener) rewriter.formulaHintsStore = hintsStore rewriter.config = RewriterConfig(tuning) @@ -132,8 +130,7 @@ class BoundedCheckerPassImpl @Inject() (val options: PassOptions, hintsStore: Fo logger.warn("SMT profiling is enabled, but offline SMT is used. No profiling data will be written.") } - val typeFinder = new TrivialTypeFinder - val rewriter: SymbStateRewriterImpl = new SymbStateRewriterImpl(solverContext, typeFinder, exprGradeStore) + val rewriter: SymbStateRewriterImpl = new SymbStateRewriterImpl(solverContext, exprGradeStore) rewriter.formulaHintsStore = hintsStore rewriter.config = RewriterConfig(tuning) diff --git a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/passes/PostTypeCheckerPassImpl.scala b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/passes/PostTypeCheckerPassImpl.scala index ea1a077160..673a799a54 100644 --- a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/passes/PostTypeCheckerPassImpl.scala +++ b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/passes/PostTypeCheckerPassImpl.scala @@ -27,5 +27,8 @@ class PostTypeCheckerPassImpl @Inject() (options: PassOptions, sourceStore: Sour extends EtcTypeCheckerPassImpl(options, sourceStore, changeListener, tracker, annotationStore, nextPass) with LazyLogging { + // in the post-checking, polytypes are not allowed, as the model checker will not be able to handle them + override def inferPoly: Boolean = false + override def name: String = "PostTypeCheckerSnowcat" } diff --git a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rewriter/SymbStateRewriterSnapshot.scala b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rewriter/SymbStateRewriterSnapshot.scala index 7b402432de..f01a626e7f 100644 --- a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rewriter/SymbStateRewriterSnapshot.scala +++ b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rewriter/SymbStateRewriterSnapshot.scala @@ -2,14 +2,12 @@ package at.forsyte.apalache.tla.bmcmt.rewriter import at.forsyte.apalache.tla.bmcmt.analyses.ExprGrade import at.forsyte.apalache.tla.bmcmt.caches.{AbstractCacheSnapshot, SimpleCacheSnapshot} -import at.forsyte.apalache.tla.bmcmt.types.eager.TrivialTypeSnapshot import at.forsyte.apalache.tla.bmcmt.{Arena, ArenaCell} import at.forsyte.apalache.tla.lir.TlaEx import scala.collection.immutable.SortedSet -class SymbStateRewriterSnapshot(val typeFinderSnapshot: TrivialTypeSnapshot, - val intValueCacheSnapshot: AbstractCacheSnapshot[Arena, BigInt, ArenaCell], +class SymbStateRewriterSnapshot(val intValueCacheSnapshot: AbstractCacheSnapshot[Arena, BigInt, ArenaCell], val intRangeCacheSnapshot: AbstractCacheSnapshot[Arena, (Int, Int), ArenaCell], val strValueCacheSnapshot: AbstractCacheSnapshot[Arena, String, ArenaCell], val recordDomainCache: AbstractCacheSnapshot[Arena, (SortedSet[String], SortedSet[String]), ArenaCell], diff --git a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/AndRule.scala b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/AndRule.scala index c5d1c57938..4080ab0cb6 100644 --- a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/AndRule.scala +++ b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/AndRule.scala @@ -4,10 +4,10 @@ import at.forsyte.apalache.tla.bmcmt._ import at.forsyte.apalache.tla.bmcmt.analyses.FormulaHintsStore import at.forsyte.apalache.tla.bmcmt.rewriter.ConstSimplifierForSmt import at.forsyte.apalache.tla.bmcmt.types.BoolT +import at.forsyte.apalache.tla.lir.TypedPredefs._ import at.forsyte.apalache.tla.lir.convenience.tla -import at.forsyte.apalache.tla.lir.UntypedPredefs._ import at.forsyte.apalache.tla.lir.oper.TlaBoolOper -import at.forsyte.apalache.tla.lir.{NameEx, OperEx, TlaEx, ValEx} +import at.forsyte.apalache.tla.lir.{BoolT1, OperEx, TlaEx, ValEx} /** * Implements the rule for conjunction. Similar to TLC, we short-circuit A /\ B as IF A THEN B ELSE FALSE. @@ -19,6 +19,7 @@ import at.forsyte.apalache.tla.lir.{NameEx, OperEx, TlaEx, ValEx} */ class AndRule(rewriter: SymbStateRewriter) extends RewritingRule { private val simplifier = new ConstSimplifierForSmt() + private val boolTypes = Map("b" -> BoolT1()) override def isApplicable(symbState: SymbState): Boolean = { symbState.ex match { @@ -38,8 +39,13 @@ class AndRule(rewriter: SymbStateRewriter) extends RewritingRule { // use short-circuiting on state-level expressions (like in TLC) def toIte(es: Seq[TlaEx]): TlaEx = { es match { - case Seq(last) => last - case hd +: tail => tla.ite(hd, toIte(tail), state.arena.cellFalse().toNameEx) + case Seq(last) => + last + + case hd +: tail => + tla + .ite(hd, toIte(tail), state.arena.cellFalse().toNameEx ? "b") + .typed(boolTypes, "b") } } @@ -71,13 +77,15 @@ class AndRule(rewriter: SymbStateRewriter) extends RewritingRule { // simply translate to a conjunction var nextState = state.updateArena(_.appendCell(BoolT())) val pred = nextState.arena.topCell.toNameEx + def mapArg(argEx: TlaEx): TlaEx = { nextState = rewriter.rewriteUntilDone(nextState.setRex(argEx)) nextState.ex } val rewrittenArgs = args map mapArg - rewriter.solverContext.assertGroundExpr(tla.eql(pred, tla.and(rewrittenArgs: _*))) + val eq = tla.eql(pred ? "b", tla.and(rewrittenArgs: _*) ? "b").typed(boolTypes, "b") + rewriter.solverContext.assertGroundExpr(eq) nextState.setRex(pred) } rewriter.rewriteUntilDone(newState) @@ -119,7 +127,10 @@ class AndRule(rewriter: SymbStateRewriter) extends RewritingRule { // propagate var nextState = tailState.updateArena(_.appendCell(BoolT())) val pred = nextState.asCell.toNameEx - rewriter.solverContext.assertGroundExpr(tla.equiv(pred, tla.and(headCell.toNameEx, tailState.ex))) + val eq = tla + .equiv(pred ? "b", tla.and(headCell.toNameEx ? "b", tailState.ex) ? "b") + .typed(boolTypes, "b") + rewriter.solverContext.assertGroundExpr(eq) nextState.setRex(pred) } } diff --git a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/ChooseRule.scala b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/ChooseRule.scala index ab69c53717..2d9362ee42 100644 --- a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/ChooseRule.scala +++ b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/ChooseRule.scala @@ -4,9 +4,9 @@ import at.forsyte.apalache.tla.bmcmt._ import at.forsyte.apalache.tla.bmcmt.rules.aux.{CherryPick, DefaultValueFactory, OracleHelper} import at.forsyte.apalache.tla.bmcmt.types.FinSetT import at.forsyte.apalache.tla.lir.convenience.tla -import at.forsyte.apalache.tla.lir.UntypedPredefs._ +import at.forsyte.apalache.tla.lir.TypedPredefs._ import at.forsyte.apalache.tla.lir.oper.TlaOper -import at.forsyte.apalache.tla.lir.{OperEx, TlaEx} +import at.forsyte.apalache.tla.lir.{BoolT1, OperEx, SetT1, TlaEx, TypingException} /** *

Rewriting rule for CHOOSE. Similar to TLC, we implement a non-deterministic choice. @@ -36,8 +36,11 @@ class ChooseRule(rewriter: SymbStateRewriter) extends RewritingRule { // This is a general encoding, handling both happy and unhappy scenarios, // that is, when CHOOSE is defined on its arguments and not, respectively. def solverAssert = rewriter.solverContext.assertGroundExpr _ + // compute set comprehension and then pick an element from it - val filterEx = tla.filter(varName, set, pred) + val filterEx = tla + .filter(varName, set, pred) + .typed(set.typeTag.asTlaType1()) var nextState = rewriter.rewriteUntilDone(state.setRex(filterEx)) // pick an arbitrary witness val setCell = nextState.asCell @@ -94,7 +97,15 @@ class ChooseRule(rewriter: SymbStateRewriter) extends RewritingRule { val trueEx = nextState.arena.cellTrue().toNameEx // pick only the elements that belong to the set - val elemsIn = elems map { e => tla.in(e.toNameEx, setCell.toNameEx).untyped() } + val elemType = setCell.cellType.toTlaType1 match { + case SetT1(tt) => tt + case tt => throw new TypingException("Expected a set, found: " + tt) + } + val elemsIn = elems map { e => + tla + .in(e.toNameEx ? "e", setCell.toNameEx ? "s") + .typed(Map("e" -> elemType, "s" -> SetT1(elemType), "b" -> BoolT1()), "b") + } solverAssert(oracle.caseAssertions(nextState, elemsIn)) nextState = pickRule.pickByOracle(nextState, oracle, elems, trueEx) val witness = nextState.asCell diff --git a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/EqRule.scala b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/EqRule.scala index 53df0e394e..caa7534df3 100644 --- a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/EqRule.scala +++ b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/EqRule.scala @@ -27,8 +27,6 @@ class EqRule(rewriter: SymbStateRewriter) extends RewritingRule { state.setRex(state.arena.cellTrue().toNameEx) case OperEx(TlaOper.eq, lhs, rhs) => - // Rewrite the both arguments in Cell theory. Although by doing so, - // we may introduce redundant cells, we don't have to think about types. var newState = rewriter.rewriteUntilDone(state.setRex(lhs)) val leftCell = newState.asCell newState = rewriter.rewriteUntilDone(newState.setRex(rhs)) diff --git a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/FunAppRule.scala b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/FunAppRule.scala index 2a682acbee..5a77001874 100644 --- a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/FunAppRule.scala +++ b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/FunAppRule.scala @@ -4,8 +4,8 @@ import at.forsyte.apalache.tla.bmcmt._ import at.forsyte.apalache.tla.bmcmt.rewriter.ConstSimplifierForSmt import at.forsyte.apalache.tla.bmcmt.rules.aux.{CherryPick, DefaultValueFactory} import at.forsyte.apalache.tla.bmcmt.types._ -import at.forsyte.apalache.tla.lir.convenience._ import at.forsyte.apalache.tla.lir.UntypedPredefs._ +import at.forsyte.apalache.tla.lir.convenience._ import at.forsyte.apalache.tla.lir.oper.TlaFunOper import at.forsyte.apalache.tla.lir.values.{TlaInt, TlaStr} import at.forsyte.apalache.tla.lir.{OperEx, TlaEx, ValEx} @@ -29,17 +29,19 @@ class FunAppRule(rewriter: SymbStateRewriter) extends RewritingRule { override def apply(state: SymbState): SymbState = { state.ex match { - case OperEx(TlaFunOper.app, funEx, argEx) => + case ex @ OperEx(TlaFunOper.app, funEx, argEx) => // SE-FUN-APP1 val funState = rewriter.rewriteUntilDone(state.setRex(funEx)) val funCell = funState.asCell + // we use funCell.cellType, not funEx.typeTag, because funEx can be the result of the rewriter funCell.cellType match { case TupleT(_) => applyTuple(funState, funCell, funEx, argEx) case RecordT(_) => - applyRecord(funState, funCell, funEx, argEx) + val resultT = CellT.fromTypeTag(ex.typeTag) + applyRecord(funState, funCell, funEx, argEx, resultT) case SeqT(_) => applySeq(funState, funCell, argEx) @@ -53,7 +55,8 @@ class FunAppRule(rewriter: SymbStateRewriter) extends RewritingRule { } } - private def applyRecord(state: SymbState, recordCell: ArenaCell, recEx: TlaEx, argEx: TlaEx): SymbState = { + private def applyRecord(state: SymbState, recordCell: ArenaCell, recEx: TlaEx, argEx: TlaEx, + resultT: CellT): SymbState = { val key = argEx match { case ValEx(TlaStr(k)) => k case _ => throw new RewriterException(s"Accessing a record $recEx with a non-constant key $argEx", argEx) @@ -67,10 +70,8 @@ class FunAppRule(rewriter: SymbStateRewriter) extends RewritingRule { if (index >= 0 && index < elems.length) { state.setRex(elems(index).toNameEx) } else { - // This case should have been caught by type inference. Throw an exception immediately. - val msg = - s"Accessing record $recEx of type ${recordCell.cellType} with the field $argEx. Type inference should have caught this." - throw new IllegalArgumentException(msg) + // The key does not belong to the record. This can happen as records of different domains can be unified + defaultValueFactory.makeUpValue(state, resultT) } } diff --git a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/FunCtorRule.scala b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/FunCtorRule.scala index e6cff5e85d..fdb64a0906 100644 --- a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/FunCtorRule.scala +++ b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/FunCtorRule.scala @@ -2,10 +2,10 @@ package at.forsyte.apalache.tla.bmcmt.rules import at.forsyte.apalache.tla.bmcmt._ import at.forsyte.apalache.tla.bmcmt.types._ +import at.forsyte.apalache.tla.lir.TypedPredefs._ import at.forsyte.apalache.tla.lir.convenience.tla -import at.forsyte.apalache.tla.lir.UntypedPredefs._ import at.forsyte.apalache.tla.lir.oper.TlaFunOper -import at.forsyte.apalache.tla.lir.{NameEx, OperEx, TlaEx} +import at.forsyte.apalache.tla.lir._ /** * The new implementation of a function constructor that encodes a function f = [x \in S |-> e] the classical way: @@ -23,9 +23,10 @@ class FunCtorRule(rewriter: SymbStateRewriter) extends RewritingRule { override def apply(state: SymbState): SymbState = { state.ex match { - case OperEx(TlaFunOper.funDef, mapEx, NameEx(varName), setEx) => + case ex @ OperEx(TlaFunOper.funDef, mapEx, NameEx(varName), setEx) => // note that we only have a single-argument case here, as Desugarer collapses multiple arguments into a tuple - rewriteFunCtor(state, mapEx, varName, setEx) + val funT = TlaType1.fromTypeTag(ex.typeTag).asInstanceOf[FunT1] + rewriteFunCtor(state, funT, mapEx, varName, setEx) case _ => throw new RewriterException( @@ -34,24 +35,20 @@ class FunCtorRule(rewriter: SymbStateRewriter) extends RewritingRule { } } - private def rewriteFunCtor(state: SymbState, mapEx: TlaEx, varName: String, setEx: TlaEx) = { + private def rewriteFunCtor(state: SymbState, funT1: FunT1, mapEx: TlaEx, varName: String, setEx: TlaEx) = { // rewrite the set expression into a memory cell var nextState = rewriter.rewriteUntilDone(state.setRex(setEx)) val domainCell = nextState.asCell - val elemT = domainCell.cellType match { - case FinSetT(et) => et - case t @ _ => throw new RewriterException("Expected a finite set, found: " + t, state.ex) - } + val funT = CellT.fromType1(funT1) + val elemT = CellT.fromType1(funT1.arg) + val resultT = CellT.fromType1(funT1.res) val domainCells = nextState.arena.getHas(domainCell) // find the type of the target expression and of the target set - val resultT = rewriter.typeFinder.computeRec(mapEx) - val funT = - rewriter.typeFinder - .compute(state.ex, resultT, elemT, domainCell.cellType) - .asInstanceOf[FunT] // unfold the set and map every potential element to a cell // actually, instead of mapping every cell to e, we map it to <> to construct the relation - val pairEx = tla.tuple(tla.name(varName), mapEx) + val pairEx = tla + .tuple(tla.name(varName).typed(funT1.arg), mapEx) + .typed(TupT1(funT1.arg, funT1.res)) val (afterMapState: SymbState, relationCells: Seq[ArenaCell]) = mapCells(nextState, pairEx, varName, setEx, domainCells) @@ -71,9 +68,9 @@ class FunCtorRule(rewriter: SymbStateRewriter) extends RewritingRule { // associate a value of the uninterpreted function with a cell def addCellCons(domElem: ArenaCell, relElem: ArenaCell): Unit = { - val inDomain = tla.in(domElem.toNameEx, domainCell.toNameEx) - val inRelation = tla.in(relElem.toNameEx, relation.toNameEx) - val iff = tla.equiv(inDomain, inRelation) + val inDomain = tla.in(domElem.toNameEx, domainCell.toNameEx).typed(BoolT1()) + val inRelation = tla.in(relElem.toNameEx, relation.toNameEx).typed(BoolT1()) + val iff = tla.equiv(inDomain, inRelation).typed(BoolT1()) rewriter.solverContext.assertGroundExpr(iff) } diff --git a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/FunExceptRule.scala b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/FunExceptRule.scala index 05d9aeaa0d..90a0a4974e 100644 --- a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/FunExceptRule.scala +++ b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/FunExceptRule.scala @@ -7,7 +7,9 @@ import at.forsyte.apalache.tla.lir.convenience._ import at.forsyte.apalache.tla.lir.oper.TlaFunOper import at.forsyte.apalache.tla.lir.values.{TlaInt, TlaStr} import at.forsyte.apalache.tla.lir.{NameEx, OperEx, TlaEx, ValEx} -import at.forsyte.apalache.tla.lir.UntypedPredefs._ +import at.forsyte.apalache.tla.lir.TypedPredefs._ +import at.forsyte.apalache.tla.lir.TlaType1 +import at.forsyte.apalache.tla.lir.{FunT1, RecT1, TupT1, BoolT1, SetT1} /** * Rewriting EXCEPT for functions, tuples, and records. @@ -28,31 +30,24 @@ class FunExceptRule(rewriter: SymbStateRewriter) extends RewritingRule { override def apply(state: SymbState): SymbState = { state.ex match { - case OperEx(TlaFunOper.except, args @ _*) => - val funEx = args.head - val indexEs = args.tail.zipWithIndex.filter(_._2 % 2 == 0).map(_._1) - // first, unpack singleton tuples in indices, see the comment to the method - val unpackedIndices = unpackSingletonIndices(indexEs) - val valEs = args.tail.zipWithIndex.filter(_._2 % 2 == 1).map(_._1) - assert(indexEs.size == valEs.size) - - // second, rewrite all the arguments - val (groundState: SymbState, groundArgs: Seq[TlaEx]) = - rewriter.rewriteSeqUntilDone(state, funEx +: (unpackedIndices ++ valEs)) - val funCell = groundState.arena.findCellByNameEx(groundArgs.head) - val indexCells = groundArgs - .slice(1, 1 + unpackedIndices.size) - .map(groundState.arena.findCellByNameEx) - val valueCells = groundArgs - .slice(1 + unpackedIndices.size, 1 + unpackedIndices.size + valEs.size) - .map(groundState.arena.findCellByNameEx) - funCell.cellType match { - case FunT(_, _) => rewriteFun(groundState, funCell, indexCells, valueCells) - case rt @ RecordT(_) => rewriteRec(groundState, funCell, rt, unpackedIndices, valueCells) - case tt @ TupleT(_) => rewriteTuple(groundState, funCell, tt, unpackedIndices, valueCells) + case ex @ OperEx(TlaFunOper.except, funEx, OperEx(TlaFunOper.tuple, indexEx), valueEx) => + // Desugarer takes care of general EXCEPT and provides us with the simple form + // rewrite the arguments + var nextState = state + nextState = rewriter.rewriteUntilDone(nextState.setRex(funEx)) + val funCell = nextState.asCell + nextState = rewriter.rewriteUntilDone(nextState.setRex(indexEx)) + val indexCell = nextState.asCell + nextState = rewriter.rewriteUntilDone(nextState.setRex(valueEx)) + val valueCell = nextState.asCell + val funT = TlaType1.fromTypeTag(ex.typeTag) + // delegate to the code that knows how to deal with the specific type + funT match { + case ft @ FunT1(_, _) => rewriteFun(nextState, funCell, ft, indexCell, valueCell) + case rt @ RecT1(_) => rewriteRec(nextState, funCell, rt, indexEx, valueCell) + case tt @ TupT1(_ @_*) => rewriteTuple(nextState, funCell, tt, indexEx, valueCell) case _ => - throw new NotImplementedError( - s"EXCEPT is not implemented for ${funCell.cellType}. Write a feature request.") + throw new NotImplementedError(s"EXCEPT is not implemented for $funT. Write a feature request.") } case _ => @@ -60,34 +55,46 @@ class FunExceptRule(rewriter: SymbStateRewriter) extends RewritingRule { } } - def rewriteFun(state: SymbState, funCell: ArenaCell, indexCells: Seq[ArenaCell], - valueCells: Seq[ArenaCell]): SymbState = { + def rewriteFun(state: SymbState, funCell: ArenaCell, funT: FunT1, indexCell: ArenaCell, + valueCell: ArenaCell): SymbState = { // rewrite tuples <> to cells - val updatePairs = indexCells.zip(valueCells) // ![j_i] = e_i - def mkPair(indexCell: ArenaCell, resCell: ArenaCell): TlaEx = tla.tuple(indexCell.toNameEx, resCell.toNameEx) + def mkPair(indexCell: ArenaCell, resCell: ArenaCell): TlaEx = { + tla + .tuple(indexCell.toNameEx, resCell.toNameEx) + .typed(TupT1(funT.arg, funT.res)) + } - val (stateAfterTuples, updateTuples) = - rewriter.rewriteSeqUntilDone(state, updatePairs map (mkPair _).tupled) - val updateTuplesAsCells = updateTuples.map(stateAfterTuples.arena.findCellByNameEx(_)) + var nextState = rewriter.rewriteUntilDone(state.setRex(mkPair(indexCell, valueCell))) + val newPairCell = nextState.asCell // get the function relation from the arena - var nextState = stateAfterTuples - val relation = state.arena.getCdm(funCell) + val relation = nextState.arena.getCdm(funCell) val relationCells = nextState.arena.getHas(relation) nextState = nextState.updateArena(_.appendCell(relation.cellType)) val resultRelation = nextState.arena.topCell // introduce a new function relation that is organized as follows: - // [ p \in f_rel |-> IF p[1] = j_1 THEN <> ELSE ... ELSE p ] - def eachRelationPair(p: ArenaCell): ArenaCell = { - val ite = toIte(nextState.arena, p, indexCells, updateTuplesAsCells) + // [ p \in f_rel |-> IF p[1] = i THEN <> ELSE p ] + def eachRelationPair(pair: ArenaCell): ArenaCell = { + val tupT = TupT1(funT.arg, funT.res) + val types = Map("p" -> tupT, "i" -> funT.arg, "b" -> BoolT1(), "r" -> SetT1(tupT)) + // Since the expression goes to the solver, we don't care about types. + val pairIndex = nextState.arena.getHas(pair).head // this is pair[1] + val ite = tla + .ite(tla.eql(pairIndex.toNameEx ? "p", indexCell.toNameEx ? "i") ? "b", newPairCell.toNameEx ? "p", + pair.toNameEx ? "p") + .typed(Map("p" -> tupT, "i" -> funT.arg, "b" -> BoolT1()), "p") + nextState = rewriter.rewriteUntilDone(nextState.setRex(ite)) val updatedCell = nextState.asCell // add the new cell to the arena immediately, as we are going to use the IN predicates nextState = nextState.updateArena(_.appendHas(resultRelation, updatedCell)) - // the new cell belongs to the new relation iff the old cell belongs to the old relation - solverAssert(tla.equiv(tla.in(p.toNameEx, relation.toNameEx), - tla.in(updatedCell.toNameEx, resultRelation.toNameEx))) + // The new cell belongs to the new relation iff the old cell belongs to the old relation. + val assertion = tla + .equiv(tla.in(pair.toNameEx ? "p", relation.toNameEx ? "r") ? "b", + tla.in(updatedCell.toNameEx ? "p", resultRelation.toNameEx ? "r") ? "b") + .typed(types, "b") + solverAssert(assertion) updatedCell } @@ -97,9 +104,7 @@ class FunExceptRule(rewriter: SymbStateRewriter) extends RewritingRule { // cache equality constraints between the indices and the indices in the function relation def cacheEqForPair(p: ArenaCell): Unit = { val pairIndex = nextState.arena.getHas(p).head - for (updateIndex <- indexCells) { - nextState = cacheEq(nextState, pairIndex, updateIndex) - } + nextState = cacheEq(nextState, pairIndex, indexCell) } // cache all equalities @@ -114,93 +119,62 @@ class FunExceptRule(rewriter: SymbStateRewriter) extends RewritingRule { .setRex(newFunCell.toNameEx) } - def rewriteRec(state: SymbState, recCell: ArenaCell, recType: RecordT, indexEs: Seq[TlaEx], - valueCells: Seq[ArenaCell]): SymbState = { - def indexToStr: TlaEx => String = { + def rewriteRec(state: SymbState, oldRecord: ArenaCell, recType: RecT1, indexEx: TlaEx, + newValue: ArenaCell): SymbState = { + + val keyToUpdate = indexEx match { case ValEx(TlaStr(key)) => key case ex => throw new RewriterException("Expected a string when updating a record, found: " + ex, ex) } - val updatedKeys = indexEs map indexToStr - val unchangedKeys = recType.fields.keySet.diff(Set(updatedKeys: _*)) - // create a new record - def mkUnchanged(key: String): (TlaEx, TlaEx) = { - (tla.str(key), tla.appFun(recCell.toNameEx, tla.str(key))) + var nextState = state.updateArena(_.appendCell(oldRecord.cellType)) + val newRecord = nextState.arena.topCell + val domain = nextState.arena.getDom(oldRecord) + // copy over the domain, as it does not change + nextState = nextState.updateArena(_.setDom(newRecord, domain)) + + // add the key-value pairs of the old record but update the key that was requested to be updated + def updateOrKeep(key: String, oldValue: ArenaCell): ArenaCell = { + if (key == keyToUpdate) { + newValue + } else { + oldValue + } } - def flattenPairs(list: Seq[TlaEx], pair: (TlaEx, TlaEx)): Seq[TlaEx] = { - pair._1 +: pair._2 +: list + for ((key, cell) <- recType.fieldTypes.keys.zip(nextState.arena.getHas(oldRecord))) { + nextState = nextState.updateArena(_.appendHasNoSmt(newRecord, updateOrKeep(key, cell))) } - // [ [k1, v1], [k2, v2], ... ] - val updatedPairs: Seq[(TlaEx, TlaEx)] = indexEs.zip(valueCells.map(_.toNameEx)) - val unchangedPairs: Seq[(TlaEx, TlaEx)] = unchangedKeys.toList.map(mkUnchanged) - val newRecEx = - OperEx(TlaFunOper.enum, (updatedPairs ++ unchangedPairs).reverse.foldLeft(Seq[TlaEx]())(flattenPairs): _*) - rewriter.rewriteUntilDone(state.setRex(newRecEx)) // let the rewriter handle this + rewriter.rewriteUntilDone(nextState.setRex(newRecord.toNameEx)) } - def rewriteTuple(state: SymbState, tupleCell: ArenaCell, tupleT: TupleT, indexEs: Seq[TlaEx], - valueCells: Seq[ArenaCell]): SymbState = { - def indexToInt: TlaEx => Int = { + def rewriteTuple(state: SymbState, oldTuple: ArenaCell, tupleT: TupT1, indexEx: TlaEx, + newValue: ArenaCell): SymbState = { + + val indexToUpdate = indexEx match { case ValEx(TlaInt(index)) => index.toInt case ex => throw new RewriterException("Expected a number when updating a tuple, found: " + ex, ex) } - val updatedIndices = indexEs map indexToInt - val updateMap = Map(updatedIndices.zip(valueCells): _*) + // create a new tuple + var nextState = state.updateArena(_.appendCell(oldTuple.cellType)) + val newTuple = nextState.arena.topCell - def updateOrKeep(i: Int): TlaEx = { - if (updateMap.contains(i)) { - updateMap(i).toNameEx + // add the indices of old tuple but update the index that was requested to be updated + def updateOrKeep(index: Int, oldValue: ArenaCell): ArenaCell = { + if (index == indexToUpdate) { + newValue } else { - tla.appFun(tupleCell.toNameEx, tla.int(i)) + oldValue } } - val tupleSize = tupleT.args.size - val newTuple = tla.tuple(1.to(tupleSize) map updateOrKeep: _*) - rewriter.rewriteUntilDone(state.setRex(newTuple)) // let the rewriter handle this - } - - def toIte(arena: Arena, pair: ArenaCell, indexCells: Seq[ArenaCell], updatePairs: Seq[ArenaCell]): TlaEx = { - val pairIndex = arena.getHas(pair).head // the first element of the pair is the index - updatePairs match { - case Seq() => pair.toNameEx // ... ELSE p - case newPair +: _ => - val updateIndex = indexCells.head // IF p[1] = i_j - tla.ite(tla.eql(pairIndex.toNameEx, updateIndex.toNameEx), newPair.toNameEx, - toIte(arena, pair, indexCells.tail, updatePairs.tail)) + for ((cell, index0based) <- nextState.arena.getHas(oldTuple).zipWithIndex) { + nextState = nextState.updateArena(_.appendHasNoSmt(newTuple, updateOrKeep(index0based + 1, cell))) } - } - - def addEqualities(state: SymbState, lhs: ArenaCell, rhs: ArenaCell): SymbState = { - rewriter.lazyEq.cacheOneEqConstraint(state, lhs, rhs) - } - // This is an important step. As we receive expressions from SANY, every index argument to EXCEPT - // is always a tuple]. For instance, the expression [f EXCEPT ![1] = 2] will be represented - // as OperEx(TlaFunOper.except, f, <<1>>, 2). Hence, we explicitly unpack singleton tuples. - // As for non-singleton tuples, they should be preprocessed. - private def unpackSingletonIndices(args: Seq[TlaEx]): Seq[TlaEx] = { - def unpack(e: TlaEx) = e match { - case OperEx(TlaFunOper.tuple, arg) => - arg // unpack - case OperEx(TlaFunOper.tuple, _*) => - throw new InternalCheckerError("TLA importer failed to preprocess a chained EXCEPT: " + e, e) - case _ => - // complain - throw new RewriterException("Expected a tuple as a function index, found: " + e, e) - } - - args map unpack - } - - private def checkType(cellType: CellT): Unit = { - cellType match { - case FunT(_, _) => () // o.k. - case _ => throw new NotImplementedError(s"EXCEPT is not implemented for $cellType. Write a feature request.") - } + rewriter.rewriteUntilDone(nextState.setRex(newTuple.toNameEx)) } } diff --git a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/IfThenElseRule.scala b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/IfThenElseRule.scala index 48cee54677..cf4e0e4716 100644 --- a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/IfThenElseRule.scala +++ b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/IfThenElseRule.scala @@ -7,7 +7,7 @@ import at.forsyte.apalache.tla.bmcmt.types._ import at.forsyte.apalache.tla.lir.convenience._ import at.forsyte.apalache.tla.lir.UntypedPredefs._ import at.forsyte.apalache.tla.lir.oper.TlaControlOper -import at.forsyte.apalache.tla.lir.{OperEx, TlaEx} +import at.forsyte.apalache.tla.lir.{OperEx, TlaEx, TlaType1} /** * Rewriting rule for IF A THEN B ELSE C. @@ -27,7 +27,7 @@ class IfThenElseRule(rewriter: SymbStateRewriter) extends RewritingRule { override def apply(state: SymbState): SymbState = { state.ex match { - case OperEx(TlaControlOper.ifThenElse, predEx, thenEx, elseEx) => + case ex @ OperEx(TlaControlOper.ifThenElse, predEx, thenEx, elseEx) => var nextState = rewriter.rewriteUntilDone(state.setRex(predEx)) val predCell = nextState.asCell // Some rules immediately return TRUE or FALSE. In combination with assignments, this may lead to rewriting errors. @@ -43,7 +43,7 @@ class IfThenElseRule(rewriter: SymbStateRewriter) extends RewritingRule { nextState = rewriter.rewriteUntilDone(nextState.setRex(elseEx)) val elseCell = nextState.asCell - val resultType = rewriter.typeFinder.compute(state.ex, BoolT(), thenCell.cellType, elseCell.cellType) + val resultType = CellT.fromTypeTag(ex.typeTag) resultType match { // basic types, we can use SMT equality case BoolT() | IntT() | ConstT() => iteBasic(nextState, resultType, predCell.toNameEx, thenCell, elseCell) diff --git a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/OrRule.scala b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/OrRule.scala index 1f0496c63a..4da4413cc1 100644 --- a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/OrRule.scala +++ b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/OrRule.scala @@ -3,10 +3,10 @@ package at.forsyte.apalache.tla.bmcmt.rules import at.forsyte.apalache.tla.bmcmt._ import at.forsyte.apalache.tla.bmcmt.rewriter.ConstSimplifierForSmt import at.forsyte.apalache.tla.bmcmt.types.BoolT -import at.forsyte.apalache.tla.lir.UntypedPredefs._ +import at.forsyte.apalache.tla.lir.TypedPredefs._ import at.forsyte.apalache.tla.lir.convenience.tla import at.forsyte.apalache.tla.lir.oper.TlaBoolOper -import at.forsyte.apalache.tla.lir.{OperEx, TlaEx, ValEx} +import at.forsyte.apalache.tla.lir.{BoolT1, OperEx, TlaEx, ValEx} /** * For state-level expressions, we express A \/ B as IF A THEN TRUE ELSE B. @@ -16,6 +16,8 @@ import at.forsyte.apalache.tla.lir.{OperEx, TlaEx, ValEx} * @author Igor Konnov */ class OrRule(rewriter: SymbStateRewriter) extends RewritingRule { + private val boolTypes = Map("b" -> BoolT1()) + override def isApplicable(symbState: SymbState): Boolean = { symbState.ex match { case OperEx(TlaBoolOper.or, _*) => true @@ -34,8 +36,13 @@ class OrRule(rewriter: SymbStateRewriter) extends RewritingRule { // use short-circuiting on state-level expressions (like in TLC) def toIte(es: Seq[TlaEx]): TlaEx = { es match { - case Seq(last) => last - case hd +: tail => tla.ite(hd, state.arena.cellTrue().toNameEx, toIte(tail)) + case Seq(last) => + last + + case hd +: tail => + tla + .ite(hd ? "b", state.arena.cellTrue().toNameEx ? "b", toIte(tail)) + .typed(boolTypes, "b") } } @@ -54,7 +61,10 @@ class OrRule(rewriter: SymbStateRewriter) extends RewritingRule { } val rewrittenArgs = args map mapArg - rewriter.solverContext.assertGroundExpr(tla.eql(pred, tla.or(rewrittenArgs: _*))) + val eq = tla + .eql(pred ? "b", tla.or(rewrittenArgs: _*) ? "b") + .typed(boolTypes, "b") + rewriter.solverContext.assertGroundExpr(eq) nextState.setRex(pred) } diff --git a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/PowSetCtorRule.scala b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/PowSetCtorRule.scala index 5ec3c192d8..75a02d7a87 100644 --- a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/PowSetCtorRule.scala +++ b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/PowSetCtorRule.scala @@ -22,13 +22,13 @@ class PowSetCtorRule(rewriter: SymbStateRewriter) extends RewritingRule { state.ex match { case OperEx(TlaSetOper.powerset, setEx) => // switch to cell theory - val nextState = rewriter.rewriteUntilDone(state.setRex(setEx)) + var nextState = rewriter.rewriteUntilDone(state.setRex(setEx)) val dom = nextState.arena.findCellByNameEx(nextState.ex) - val arena = nextState.arena.appendCell(PowSetT(dom.cellType)) - val powSetCell = arena.topCell - val newArena = arena.setDom(powSetCell, dom) - state.setArena(newArena).setRex(powSetCell.toNameEx) + nextState = nextState.updateArena(_.appendCell(PowSetT(dom.cellType))) + val powSetCell = nextState.arena.topCell + nextState = nextState.updateArena(_.setDom(powSetCell, dom)) + nextState.setRex(powSetCell.toNameEx) case _ => throw new RewriterException("%s is not applicable".format(getClass.getSimpleName), state.ex) diff --git a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/QuantRule.scala b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/QuantRule.scala index eb72192f0c..18b4c68522 100644 --- a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/QuantRule.scala +++ b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/QuantRule.scala @@ -50,11 +50,9 @@ class QuantRule(rewriter: SymbStateRewriter) extends RewritingRule with LazyLogg skolemExistsInSet(setState, boundVar, predEx, set) case PowSetT(FinSetT(_)) => - () skolemExistsByPick(setState, boundVar, predEx, set) case FinFunSetT(_, _) => - () skolemExistsByPick(setState, boundVar, predEx, set) case tp => diff --git a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/RecCtorRule.scala b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/RecCtorRule.scala index 4e452c623e..a14dd5f464 100644 --- a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/RecCtorRule.scala +++ b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/RecCtorRule.scala @@ -17,10 +17,6 @@ import scala.collection.immutable.SortedSet * Internally, a record is stored as a tuple, * where an index i corresponds to the ith key in the sorted set of record keys. * - * Note that one can extend a record with a type annotation, e.g., - * the expression [a |-> 1] <: [a |-> Int, b |-> BOOLEAN] introduces a record r of two fields (a: Int and b: BOOLEAN). - * The value r.a is defined as 1, whereas r.b is arbitrary. - * * @author Igor Konnov */ class RecCtorRule(rewriter: SymbStateRewriter) extends RewritingRule { @@ -35,7 +31,7 @@ class RecCtorRule(rewriter: SymbStateRewriter) extends RewritingRule { override def apply(state: SymbState): SymbState = { state.ex match { - case OperEx(TlaFunOper.enum, elems @ _*) => + case ex @ OperEx(TlaFunOper.enum, elems @ _*) => val keyEs = elems.zipWithIndex.filter(_._2 % 2 == 0).map(_._1) // pick the even indices (starting with 0) val ctorKeys = keysToStr(state.ex, keyEs.toList) val valueEs = elems.zipWithIndex.filter(_._2 % 2 == 1).map(_._1) // pick the odd indices (starting with 0) @@ -45,13 +41,14 @@ class RecCtorRule(rewriter: SymbStateRewriter) extends RewritingRule { rewriter.rewriteSeqUntilDone(state, valueEs) // compute the types of the field values and then use the type finder val valueCells = newValues.map(newState.arena.findCellByNameEx) - val typeArgs = elems.zipWithIndex.map(p => if (p._2 % 2 == 0) ConstT() else valueCells(p._2 / 2).cellType) - val recordT = rewriter.typeFinder.compute(state.ex, typeArgs: _*).asInstanceOf[RecordT] - // the computed record type may contain additional keys, due to a type annotation + + // the record type may contain more fields than passed in the arguments + val recordT = CellT.fromTypeTag(ex.typeTag).asInstanceOf[RecordT] var arena = newState.arena.appendCell(recordT) val recordCell = arena.topCell // importantly, the record keys that are outside of ctorKeys should not belong to the domain! val extraKeys = recordT.fields.keySet.filter(k => !ctorKeys.contains(k)) + def addExtra(map: Map[String, ArenaCell], key: String) = { // make sure that the key is cached, as it does not appear in the actual expression val (newArena, keyCell) = rewriter.strValueCache.getOrCreate(arena, key) diff --git a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/RecFunDefAndRefRule.scala b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/RecFunDefAndRefRule.scala index 7068496d43..3f8bce84df 100644 --- a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/RecFunDefAndRefRule.scala +++ b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/RecFunDefAndRefRule.scala @@ -6,8 +6,8 @@ import at.forsyte.apalache.tla.bmcmt.types._ import at.forsyte.apalache.tla.lir.UntypedPredefs._ import at.forsyte.apalache.tla.lir.convenience.tla import at.forsyte.apalache.tla.lir.oper.TlaFunOper -import at.forsyte.apalache.tla.lir.values.TlaIntSet -import at.forsyte.apalache.tla.lir.{NameEx, OperEx, TlaEx, ValEx} +import at.forsyte.apalache.tla.lir.values.{TlaBoolSet, TlaIntSet} +import at.forsyte.apalache.tla.lir.{BoolT1, FunT1, IntT1, NameEx, OperEx, TlaEx, TlaType1, ValEx} /** * This rule translates the definition of a recursive function. It is similar to CHOOSE. @@ -29,9 +29,10 @@ class RecFunDefAndRefRule(rewriter: SymbStateRewriter) extends RewritingRule { override def apply(state: SymbState): SymbState = { state.ex match { - case OperEx(TlaFunOper.recFunDef, mapEx, NameEx(varName), setEx) => + case ex @ OperEx(TlaFunOper.recFunDef, mapEx, NameEx(varName), setEx) => // note that we only have a single-argument case here, as Desugarer collapses multiple arguments into a tuple - rewriteFunCtor(state, mapEx, varName, setEx) + val funT1 = TlaType1.fromTypeTag(ex.typeTag).asInstanceOf[FunT1] + rewriteFunCtor(state, funT1, mapEx, varName, setEx) case OperEx(TlaFunOper.recFunRef) => val name = TlaFunOper.recFunRef.uniqueName @@ -48,29 +49,30 @@ class RecFunDefAndRefRule(rewriter: SymbStateRewriter) extends RewritingRule { } } - private def rewriteFunCtor(state: SymbState, mapEx: TlaEx, varName: String, setEx: TlaEx) = { + private def rewriteFunCtor(state: SymbState, funT1: FunT1, mapEx: TlaEx, varName: String, setEx: TlaEx) = { // rewrite the set expression into a memory cell var nextState = rewriter.rewriteUntilDone(state.setRex(setEx)) - val domainCell = nextState.asCell - val elemT = domainCell.cellType match { - case FinSetT(et) => et - case t @ _ => throw new RewriterException("Expected a finite set, found: " + t, state.ex) - } - // find the type of the target expression and of the target set - val resultT = rewriter.typeFinder.computeRec(mapEx) - val codomain = - resultT match { - case IntT() => ValEx(TlaIntSet) - case BoolT() => tla.booleanSet().untyped() - case _ => + + val funT = CellT.fromType1(funT1) + val (elemT, codomain) = + funT1 match { + case FunT1(argT, IntT1()) => + (CellT.fromType1(argT), ValEx(TlaIntSet)) + + case FunT1(argT, BoolT1()) => + (CellT.fromType1(argT), ValEx(TlaBoolSet)) + + case FunT1(argT, resultT) => val msg = "A result of a recursive function must belong to Int or BOOLEAN. Found: " + resultT throw new RewriterException(msg, state.ex) } - val funT = - rewriter.typeFinder - .compute(state.ex, resultT, elemT, domainCell.cellType) - .asInstanceOf[FunT] + // one more safety check, as the domain cell can happen to be a powerset or a function set + val domainCell = nextState.asCell + domainCell.cellType match { + case FinSetT(et) => et + case t @ _ => throw new RewriterException("Expected a finite set, found: " + t, state.ex) + } // produce a cell for the function set (no expansion happens there) nextState = rewriter.rewriteUntilDone(nextState.setRex(tla.funSet(domainCell.toNameEx, codomain))) diff --git a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/SeqOpsRule.scala b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/SeqOpsRule.scala index ccde4971b1..05d8cfeb8b 100644 --- a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/SeqOpsRule.scala +++ b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/SeqOpsRule.scala @@ -1,9 +1,9 @@ package at.forsyte.apalache.tla.bmcmt.rules import at.forsyte.apalache.tla.bmcmt._ -import at.forsyte.apalache.tla.lir.{OperEx, TlaEx} +import at.forsyte.apalache.tla.lir.{BoolT1, IntT1, OperEx, TlaEx, TlaType1} import at.forsyte.apalache.tla.lir.convenience.tla -import at.forsyte.apalache.tla.lir.UntypedPredefs._ +import at.forsyte.apalache.tla.lir.TypedPredefs._ import at.forsyte.apalache.tla.bmcmt.rules.aux.CherryPick import at.forsyte.apalache.tla.bmcmt.types.CellT import at.forsyte.apalache.tla.lir.oper.TlaSeqOper @@ -15,6 +15,7 @@ import at.forsyte.apalache.tla.lir.oper.TlaSeqOper */ class SeqOpsRule(rewriter: SymbStateRewriter) extends RewritingRule { private val picker = new CherryPick(rewriter) + private val types = Map("i" -> IntT1(), "b" -> BoolT1()) override def isApplicable(symbState: SymbState): Boolean = { symbState.ex match { @@ -31,8 +32,9 @@ class SeqOpsRule(rewriter: SymbStateRewriter) extends RewritingRule { override def apply(state: SymbState): SymbState = { state.ex match { - case OperEx(TlaSeqOper.head, seq) => - rewriter.rewriteUntilDone(state.setRex(tla.appFun(seq, tla.int(1)))) + case ex @ OperEx(TlaSeqOper.head, seq) => + val elemType = TlaType1.fromTypeTag(ex.typeTag) + rewriter.rewriteUntilDone(state.setRex(tla.appFun(seq, tla.int(1)).typed(elemType))) case OperEx(TlaSeqOper.len, seq) => translateLen(state, seq) @@ -62,7 +64,10 @@ class SeqOpsRule(rewriter: SymbStateRewriter) extends RewritingRule { val end = cells.tail.head // increment start, unless it goes over the bound val updatedStart = - tla.ite(tla.lt(start.toNameEx, end.toNameEx), tla.plus(tla.int(1), start.toNameEx), start.toNameEx) + tla + .ite(tla.lt(start.toNameEx ? "i", end.toNameEx ? "i") ? "b", tla.plus(tla.int(1), start.toNameEx ? "i") ? "i", + start.toNameEx ? "i") + .typed(Map("i" -> IntT1(), "b" -> BoolT1()), "i") // increment start nextState = rewriter.rewriteUntilDone(nextState.setRex(updatedStart)) val newStart = nextState.asCell @@ -75,6 +80,7 @@ class SeqOpsRule(rewriter: SymbStateRewriter) extends RewritingRule { private def translateSubSeq(state: SymbState, seq: TlaEx, newStartEx: TlaEx, newEndEx: TlaEx) = { var nextState = state + def rewriteToCell(ex: TlaEx): ArenaCell = { nextState = rewriter.rewriteUntilDone(nextState.setRex(ex)) nextState.asCell @@ -85,20 +91,30 @@ class SeqOpsRule(rewriter: SymbStateRewriter) extends RewritingRule { val start = cells.head val end = cells.tail.head + val expectedStartEx = tla + .plus(start.toNameEx ? "i", tla.minus(newStartEx ? "i", tla.int(1)) ? "i") + .typed(types, "i") // compute the new interval [expectedStart, expectedEnd) - val expectedStart = rewriteToCell(tla.plus(start.toNameEx, tla.minus(newStartEx, tla.int(1)))) - val expectedEnd = rewriteToCell(tla.plus(start.toNameEx, newEndEx)) + val expectedStart = rewriteToCell(expectedStartEx) + val expectedEndEx = tla + .plus(start.toNameEx ? "i", newEndEx) + .typed(types, "i") + val expectedEnd = rewriteToCell(expectedEndEx) // use the computed values, as soon as they do not violate the invariant: // start <= end, start >= oldStart, end <= oldEnd val seqInvariant = rewriteToCell( - tla.and( - tla.le(expectedStart.toNameEx, expectedEnd.toNameEx), - tla.le(start.toNameEx, expectedStart.toNameEx), - tla.le(expectedEnd.toNameEx, end.toNameEx) - )) - - val newStart = rewriteToCell(tla.ite(seqInvariant.toNameEx, expectedStart.toNameEx, tla.int(0))) - val newEnd = rewriteToCell(tla.ite(seqInvariant.toNameEx, expectedEnd.toNameEx, tla.int(0))) + tla + .and( + tla.le(expectedStart.toNameEx ? "i", expectedEnd.toNameEx ? "i") ? "b", + tla.le(start.toNameEx ? "i", expectedStart.toNameEx ? "i") ? "b", + tla.le(expectedEnd.toNameEx ? "i", end.toNameEx ? "i") ? "b" + ) + .typed(types, "b")) + + val newStart = + rewriteToCell(tla.ite(seqInvariant.toNameEx ? "b", expectedStart.toNameEx ? "i", tla.int(0)).typed(types, "i")) + val newEnd = + rewriteToCell(tla.ite(seqInvariant.toNameEx ? "b", expectedEnd.toNameEx ? "i", tla.int(0)).typed(types, "i")) // introduce a new sequence that whose start and end are updated as required nextState = nextState.updateArena(_.appendCell(seqCell.cellType)) @@ -115,7 +131,7 @@ class SeqOpsRule(rewriter: SymbStateRewriter) extends RewritingRule { val start = cells.head val end = cells.tail.head val oldElems = cells.tail.tail - nextState = rewriter.rewriteUntilDone(nextState.setRex(tla.plus(tla.int(1), end.toNameEx))) + nextState = rewriter.rewriteUntilDone(nextState.setRex(tla.plus(tla.int(1), end.toNameEx ? "i").typed(types, "i"))) val newEnd = nextState.asCell nextState = rewriter.rewriteUntilDone(nextState.setRex(newElem)) val newElemCell = nextState.asCell @@ -128,10 +144,16 @@ class SeqOpsRule(rewriter: SymbStateRewriter) extends RewritingRule { nextState = picker.pickByOracle(nextState, oracle, Seq(oldElemCell, newElemCell), nextState.arena.cellTrue().toNameEx) // pick the element from the old sequence: start <= no /\ no < end => oracle = 0 - solverAssert(tla.impl(tla.and(tla.le(start.toNameEx, tla.int(no)), tla.lt(tla.int(no), end.toNameEx)), - oracle.whenEqualTo(nextState, 0))) + solverAssert( + tla + .impl(tla.and(tla.le(start.toNameEx ? "i", tla.int(no)) ? "b", + tla.lt(tla.int(no), end.toNameEx ? "i") ? "b") ? "b", oracle.whenEqualTo(nextState, 0) ? "b") + .typed(types, "b")) // pick the element from the new sequence: no = end => oracle = 1 - solverAssert(tla.impl(tla.eql(tla.int(no), end.toNameEx), oracle.whenEqualTo(nextState, 1))) + solverAssert( + tla + .impl(tla.eql(tla.int(no), end.toNameEx ? "i") ? "b", oracle.whenEqualTo(nextState, 1) ? "b") + .typed(types, "b")) // the other elements are unrestricted, give some freedom to the solver nextState.asCell } @@ -152,7 +174,7 @@ class SeqOpsRule(rewriter: SymbStateRewriter) extends RewritingRule { val start = cells.head val end = cells.tail.head // just return end - start - rewriter.rewriteUntilDone(nextState.setRex(tla.minus(end.toNameEx, start.toNameEx))) + rewriter.rewriteUntilDone(nextState.setRex(tla.minus(end.toNameEx ? "i", start.toNameEx ? "i").typed(types, "i"))) } // Implement concatenation on sequences. This is the most expensive operation on sequences. @@ -168,7 +190,8 @@ class SeqOpsRule(rewriter: SymbStateRewriter) extends RewritingRule { val start = connectedCells.head val end = connectedCells.tail.head val elems = connectedCells.tail.tail - nextState = rewriter.rewriteUntilDone(nextState.setRex(tla.minus(end.toNameEx, start.toNameEx))) + nextState = rewriter.rewriteUntilDone( + nextState.setRex(tla.minus(end.toNameEx ? "i", start.toNameEx ? "i").typed(types, "i"))) val len = nextState.asCell (start, end, len, elems, cell.cellType) } @@ -180,9 +203,10 @@ class SeqOpsRule(rewriter: SymbStateRewriter) extends RewritingRule { // introduce a new sequence nextState = nextState.updateArena(_.appendCell(cellType)) val seq3 = nextState.arena.topCell - nextState = rewriter.rewriteUntilDone(nextState.setRex(tla.int(0))) + nextState = rewriter.rewriteUntilDone(nextState.setRex(tla.int(0).typed())) val start3 = nextState.asCell - nextState = rewriter.rewriteUntilDone(nextState.setRex(tla.plus(len1.toNameEx, len2.toNameEx))) + nextState = rewriter.rewriteUntilDone( + nextState.setRex(tla.plus(len1.toNameEx ? "i", len2.toNameEx ? "i").typed(types, "i"))) val end3 = nextState.asCell val elems1then2 = elems1 ++ elems2 @@ -191,9 +215,14 @@ class SeqOpsRule(rewriter: SymbStateRewriter) extends RewritingRule { // pre-compute integer constants that are used when computing every element // offset2 = N1 + start2 - len1 nextState = rewriter - .rewriteUntilDone(nextState.setRex(tla.minus(tla.plus(tla.int(elems1.size), start2.toNameEx), len1.toNameEx))) + .rewriteUntilDone( + nextState.setRex( + tla + .minus(tla.plus(tla.int(elems1.size) ? "i", start2.toNameEx ? "i") ? "i", len1.toNameEx ? "i") + .typed(types, "i"))) val offset2 = nextState.asCell - nextState = rewriter.rewriteUntilDone(nextState.setRex(tla.plus(len1.toNameEx, len2.toNameEx))) + nextState = rewriter.rewriteUntilDone( + nextState.setRex(tla.plus(len1.toNameEx ? "i", len2.toNameEx ? "i").typed(types, "i"))) val len1plus2 = nextState.asCell // introduce constraints for the i-th element of the resulting sequence @@ -207,20 +236,23 @@ class SeqOpsRule(rewriter: SymbStateRewriter) extends RewritingRule { // If 0 <= i < len1, then require oracle = i + start1, // If len1 <= i < len1 + len2, then require oracle = i + offset2 = i - len1 + N1 + start2, // Otherwise, set oracle to N - val inRange1 = tla.lt(tla.int(i), len1.toNameEx) + val inRange1 = tla.lt(tla.int(i), len1.toNameEx ? "i") ? "b" val inRange2 = - tla.and(tla.le(len1.toNameEx, tla.int(i)), tla.lt(tla.int(i), len1plus2.toNameEx)) + tla.and(tla.le(len1.toNameEx ? "i", tla.int(i)) ? "b", tla.lt(tla.int(i), len1plus2.toNameEx ? "i") ? "b") ? "b" val whenInRange1 = - tla.or(tla.not(inRange1), tla.eql(oracle.intCell.toNameEx, tla.plus(tla.int(i), start1.toNameEx))) + tla.or(tla.not(inRange1) ? "b", + tla.eql(oracle.intCell.toNameEx ? "i", tla.plus(tla.int(i), start1.toNameEx ? "i") ? "i") ? "b") ? "b" val whenInRange2 = - tla.or(tla.not(inRange2), tla.eql(oracle.intCell.toNameEx, tla.plus(tla.int(i), offset2.toNameEx))) + tla.or(tla.not(inRange2) ? "b", + tla.eql(oracle.intCell.toNameEx ? "i", tla.plus(tla.int(i), offset2.toNameEx ? "i") ? "i") ? "b") ? "b" val whenOutOfRange = - tla.or(tla.lt(tla.int(i), len1plus2.toNameEx), tla.eql(oracle.intCell.toNameEx, tla.int(ntotal))) + tla.or(tla.lt(tla.int(i), len1plus2.toNameEx ? "i") ? "b", + tla.eql(oracle.intCell.toNameEx ? "i", tla.int(ntotal)) ? "b") ? "b" - solverAssert(whenInRange1) - solverAssert(whenInRange2) - solverAssert(whenOutOfRange) + solverAssert(whenInRange1.typed(types, "b")) + solverAssert(whenInRange2.typed(types, "b")) + solverAssert(whenOutOfRange.typed(types, "b")) pickedResult } diff --git a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/SetCtorRule.scala b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/SetCtorRule.scala index 316a2d4637..133c24fd17 100644 --- a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/SetCtorRule.scala +++ b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/SetCtorRule.scala @@ -1,9 +1,9 @@ package at.forsyte.apalache.tla.bmcmt.rules import at.forsyte.apalache.tla.bmcmt._ -import at.forsyte.apalache.tla.bmcmt.types.FinSetT +import at.forsyte.apalache.tla.bmcmt.types.{CellT, FinSetT} import at.forsyte.apalache.tla.lir.oper.TlaSetOper -import at.forsyte.apalache.tla.lir.{OperEx, TlaEx} +import at.forsyte.apalache.tla.lir.{OperEx, TlaEx, TlaType1} import at.forsyte.apalache.tla.lir.UntypedPredefs._ /** @@ -21,27 +21,26 @@ class SetCtorRule(rewriter: SymbStateRewriter) extends RewritingRule { override def apply(state: SymbState): SymbState = { state.ex match { - case OperEx(TlaSetOper.enumSet, elems @ _*) => - // switch to cell theory - val (newState: SymbState, newEs: Seq[TlaEx]) = + case ex @ OperEx(TlaSetOper.enumSet, elems @ _*) => + val (newState, newEs: Seq[TlaEx]) = rewriter.rewriteSeqUntilDone(state, elems) - val cells = newEs.map(newState.arena.findCellByNameEx) - // compute the set type using the type finder - val elemType = rewriter.typeFinder.compute(state.ex, cells.map(_.cellType): _*) match { + var nextState = newState + val cells = newEs.map(nextState.arena.findCellByNameEx) + val setT = CellT.fromTypeTag(ex.typeTag) + val elemType = setT match { case FinSetT(et) => et case setT @ _ => throw new TypeException("Expected a finite set, found: " + setT, state.ex) } - val arena = newState.arena.appendCell(FinSetT(elemType)) - val newCell = arena.topCell - val newArena = arena.appendHas(newCell, cells: _*) + nextState = nextState.updateArena(_.appendCell(FinSetT(elemType))) + val newSetCell = nextState.arena.topCell + nextState = nextState.updateArena(_.appendHas(newSetCell, cells: _*)) - def addIn(c: ArenaCell): Unit = { - val inExpr = OperEx(TlaSetOper.in, c.toNameEx, newCell.toNameEx) + for (c <- cells) { + val inExpr = OperEx(TlaSetOper.in, c.toNameEx, newSetCell.toNameEx) rewriter.solverContext.assertGroundExpr(inExpr) } - cells.foreach(addIn) - state.setArena(newArena).setRex(newCell.toNameEx) + nextState.setRex(newSetCell.toNameEx) case _ => throw new RewriterException("%s is not applicable".format(getClass.getSimpleName), state.ex) diff --git a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/SetFilterRule.scala b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/SetFilterRule.scala index 53c4ab73af..656a547f6b 100644 --- a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/SetFilterRule.scala +++ b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/SetFilterRule.scala @@ -21,7 +21,7 @@ class SetFilterRule(rewriter: SymbStateRewriter) extends RewritingRule { override def apply(state: SymbState): SymbState = { state.ex match { - case OperEx(TlaSetOper.filter, NameEx(varName), setEx, predEx) => + case ex @ OperEx(TlaSetOper.filter, NameEx(varName), setEx, predEx) => // rewrite the set expression into a memory cell var newState = rewriter.rewriteUntilDone(state.setRex(setEx)) newState = newState.asCell.cellType match { @@ -48,8 +48,8 @@ class SetFilterRule(rewriter: SymbStateRewriter) extends RewritingRule { val filteredCellsAndPreds = (potentialCells zip computedPreds) filter (_._2 != NullEx) // get the result type from the type finder - val resultType = rewriter.typeFinder.compute(state.ex, ConstT(), setCell.cellType, BoolT()) - assert(PartialFunction.cond(resultType) { case FinSetT(_) => true }) + val resultType = CellT.fromTypeTag(ex.typeTag) + assert(resultType.isInstanceOf[FinSetT]) // introduce a new set val arena = newState.arena.appendCell(resultType) diff --git a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/SetInRule.scala b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/SetInRule.scala index a7ab4aba48..dddf8f55e4 100644 --- a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/SetInRule.scala +++ b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/SetInRule.scala @@ -141,22 +141,16 @@ class SetInRule(rewriter: SymbStateRewriter) extends RewritingRule { private def basicIn(state: SymbState, setCell: ArenaCell, elemCell: ArenaCell, elemType: types.CellT) = { val potentialElems = state.arena.getHas(setCell) - assert(elemCell.cellType == elemType) // otherwise, type finder is incorrect + // The types of the element and the set may slightly differ, but they must be unifiable. + // For instance, [a |-> 1] \in { [a |-> 2], [a |-> 3, b -> "foo"] } + assert(elemCell.cellType.unify(elemType).nonEmpty) if (potentialElems.isEmpty) { - // SE-SET-IN1: the set cell points to no other cell => return false + // the set cell points to no other cell => return false state.setRex(state.arena.cellFalse().toNameEx) } else { var nextState = state.updateArena(_.appendCell(BoolT())) val pred = nextState.arena.topCell.toNameEx - // BUGFIX 06.05.2020: in rare combinations of \A and \in, - // the rule below is not sound - //if (state.arena.isLinkedViaHas(setCell, elemCell)) { - // SE-SET-IN2: the element cell is already in the arena, just check dynamic membership - // rewriter.solverContext.assertGroundExpr(tla.eql(pred, tla.in(elemCell, state.ex))) - // nextState.setTheory(CellTheory()).setRex(pred) - //} else { - // SE-SET-IN3: general case, generate equality constraints, if needed, and use them // cache equality constraints first val eqState = rewriter.lazyEq.cacheEqConstraints(nextState, potentialElems.map((_, elemCell))) diff --git a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/SubstRule.scala b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/SubstRule.scala index 84be6b57ca..b2481948d4 100644 --- a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/SubstRule.scala +++ b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/SubstRule.scala @@ -3,6 +3,8 @@ package at.forsyte.apalache.tla.bmcmt.rules import at.forsyte.apalache.tla.bmcmt._ import at.forsyte.apalache.tla.lir.NameEx import at.forsyte.apalache.tla.lir.UntypedPredefs._ +import at.forsyte.apalache.tla.pp.TlaInputError +import com.typesafe.scalalogging.LazyLogging /** * Substitutes a bound name with a cell. For instance, it substitutes a name that is declared with VARIABLE or CONSTANT, @@ -10,7 +12,7 @@ import at.forsyte.apalache.tla.lir.UntypedPredefs._ * * @author Igor Konnov */ -class SubstRule(rewriter: SymbStateRewriter) extends RewritingRule { +class SubstRule(rewriter: SymbStateRewriter) extends RewritingRule with LazyLogging { override def isApplicable(state: SymbState): Boolean = { state.ex match { case NameEx(x) => @@ -28,7 +30,9 @@ class SubstRule(rewriter: SymbStateRewriter) extends RewritingRule { val cell = state.binding(x) state.setRex(NameEx(cell.toString)) } else { - throw new RewriterException(s"${getClass.getSimpleName}: Variable $x is not assigned a value", state.ex) + logger.error("This error may show up when CONSTANTS are not initialized.") + logger.error("Check the manual: https://apalache.informal.systems/docs/apalache/parameters.html") + throw new TlaInputError(s"${getClass.getSimpleName}: Variable $x is not assigned a value") } case _ => diff --git a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/TlcRule.scala b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/TlcRule.scala index 40c4981336..34b667b474 100644 --- a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/TlcRule.scala +++ b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/TlcRule.scala @@ -1,19 +1,19 @@ package at.forsyte.apalache.tla.bmcmt.rules import at.forsyte.apalache.tla.bmcmt._ -import at.forsyte.apalache.tla.bmcmt.types.FailPredT -import at.forsyte.apalache.tla.lir.UntypedPredefs._ +import at.forsyte.apalache.tla.lir.TypedPredefs._ import at.forsyte.apalache.tla.lir.convenience.tla import at.forsyte.apalache.tla.lir.oper.TlcOper import at.forsyte.apalache.tla.lir.values.TlaStr -import at.forsyte.apalache.tla.lir.{OperEx, TlaEx, ValEx} +import at.forsyte.apalache.tla.lir.{BoolT1, FunT1, OperEx, SetT1, TlaEx, TlaType1, TupT1, ValEx} +import com.typesafe.scalalogging.LazyLogging /** * Implements the rules for TLC operators. * * @author Igor Konnov */ -class TlcRule(rewriter: SymbStateRewriter) extends RewritingRule { +class TlcRule(rewriter: SymbStateRewriter) extends RewritingRule with LazyLogging { override def isApplicable(symbState: SymbState): Boolean = { symbState.ex match { case OperEx(TlcOper.print, _, _) => true @@ -31,60 +31,71 @@ class TlcRule(rewriter: SymbStateRewriter) extends RewritingRule { state.setRex(state.arena.cellTrue().toNameEx) case OperEx(TlcOper.assert, value, ValEx(TlaStr(message))) => - rewriteAssert(state, value, message) + // We do not support Assert as it is intrinsically imperative. + // There is an open issue on that: https://github.com/informalsystems/apalache/issues/23 + // Maybe one day we find a way to implement it via slicing. + logger.warn(s"""Met TLC!Assert("$message"). Interpreting it as TRUE.""") + state.setRex(state.arena.cellTrue().toNameEx) - case OperEx(TlcOper.colonGreater, arg, res) => // a :> b - state.setRex(tla.tuple(arg, res)) // just construct a tuple + case ex @ OperEx(TlcOper.colonGreater, arg, res) => // a :> b + // We introduce a singleton function [ x \in {arg} |-> res ]. + val funT1 = TlaType1.fromTypeTag(ex.typeTag).asInstanceOf[FunT1] + // produce a temporary name + val tempName = "__t" + ex.ID + val funEx = tla + .funDef(res, tla.name(tempName).typed(funT1.arg), tla.enumSet(arg).typed(SetT1(funT1.res))) + .typed(funT1) + // translate the new function definition with rewriter + rewriter.rewriteUntilDone(state.setRex(funEx)) - case OperEx(TlcOper.atat, funEx, pairEx) => - // f @@ a :> b, the type checker should take care of types - extendFun(state, funEx, pairEx) + case OperEx(TlcOper.atat, lhs, rhs) => + // lhs @@ rhs, the type checker should take care of types + extendFun(state, lhs, rhs) case _ => throw new RewriterException("%s is not applicable".format(getClass.getSimpleName), state.ex) } } - private def extendFun(state: SymbState, funEx: TlaEx, pairEx: TlaEx): SymbState = { + private def extendFun(state: SymbState, leftFunEx: TlaEx, rightFunEx: TlaEx): SymbState = { def solverAssert = rewriter.solverContext.assertGroundExpr _ - var nextState = rewriter.rewriteUntilDone(state.setRex(funEx)) - val funCell = nextState.asCell - val relation = nextState.arena.getCdm(funCell) - val relationCells = nextState.arena.getHas(relation) - nextState = rewriter.rewriteUntilDone(nextState.setRex(pairEx)) - val newPair = nextState.asCell - nextState = nextState.updateArena(_.appendCell(funCell.cellType)) + + val funT1 = TlaType1.fromTypeTag(leftFunEx.typeTag).asInstanceOf[FunT1] + var nextState = rewriter.rewriteUntilDone(state.setRex(leftFunEx)) + val leftFunCell = nextState.asCell + nextState = rewriter.rewriteUntilDone(nextState.setRex(rightFunEx)) + val rightFunCell = nextState.asCell + // Blindly concatenate both relations. If the domains of the functions intersect, we may produce an unsound encoding. + // As TLC!@@ is used only to decode counterexamples, it should be ok. Otherwise, we would produce a lot of constraints. + val leftRelation = nextState.arena.getCdm(leftFunCell) + val leftPairs = nextState.arena.getHas(leftRelation) + val rightRelation = nextState.arena.getCdm(rightFunCell) + val rightPairs = nextState.arena.getHas(rightRelation) + val jointPairs = leftPairs ++ rightPairs + nextState = nextState.updateArena(_.appendCell(leftFunCell.cellType)) val newFunCell = nextState.arena.topCell - nextState = nextState.updateArena(_.appendCell(relation.cellType)) + nextState = nextState.updateArena(_.appendCell(leftRelation.cellType)) val newRelation = nextState.arena.topCell - nextState = nextState.setArena( - nextState.arena - .setCdm(newFunCell, newRelation) - .appendHas(newRelation, newPair +: relationCells: _*)) + nextState = nextState.updateArena(_.setCdm(newFunCell, newRelation) + .appendHas(newRelation, jointPairs: _*)) + + // As we pass the expressions to SMT, we could use untyped expressions. + // We don't do it, in order to avoid mixing untyped and typed expressions in the same class. + val pairT = TupT1(funT1.arg, funT1.res) + val types = + Map("b" -> BoolT1(), "p" -> pairT, "r" -> SetT1(pairT)) // the new pair unconditionally belongs to the new cell - solverAssert(tla.in(newPair.toNameEx, newRelation.toNameEx)) - for (oldPair <- relationCells) { - val inOld = tla.in(oldPair.toNameEx, relation.toNameEx) - val inNew = tla.in(oldPair.toNameEx, newRelation.toNameEx) - solverAssert(tla.equiv(inNew, inOld)) + for (pair <- leftPairs) { + val inOld = tla.in(pair.toNameEx ? "p", leftRelation.toNameEx ? "r") ? "b" + val inNew = tla.in(pair.toNameEx ? "p", newRelation.toNameEx ? "r") ? "b" + solverAssert(tla.equiv(inNew, inOld).typed(types, "b")) + } + for (pair <- rightPairs) { + val inOld = tla.in(pair.toNameEx ? "p", rightRelation.toNameEx ? "r") ? "b" + val inNew = tla.in(pair.toNameEx ? "p", newRelation.toNameEx ? "r") ? "b" + solverAssert(tla.equiv(inNew, inOld).typed(types, "b")) } nextState.setRex(newFunCell.toNameEx) } - - private def rewriteAssert(state: SymbState, value: TlaEx, message: String) = { - val valueState = rewriter.rewriteUntilDone(state.setRex(value)) - // introduce a new failure predicate - var arena = state.arena.appendCell(FailPredT()) - val failPred = arena.topCell - rewriter.addMessage(failPred.id, "Assertion error: " + message) - val assertion = valueState.ex - val constraint = tla.impl(failPred.toNameEx, tla.not(assertion)) - rewriter.solverContext.assertGroundExpr(constraint) - // return isReachable. If there is a model M s.t. M |= isReachable, then M |= failPred allows us - // to check, whether the assertion is violated or not - valueState - .setArena(arena) - .setRex(state.arena.cellTrue().toNameEx) // if you need a value of a type different from bool, use TypedAssert - } } diff --git a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/TupleOrSeqCtorRule.scala b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/TupleOrSeqCtorRule.scala index 87ca18a07e..fbb7895d8b 100644 --- a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/TupleOrSeqCtorRule.scala +++ b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/TupleOrSeqCtorRule.scala @@ -24,14 +24,14 @@ class TupleOrSeqCtorRule(rewriter: SymbStateRewriter) extends RewritingRule { override def apply(state: SymbState): SymbState = { state.ex match { - case OperEx(TlaFunOper.tuple, elems @ _*) => + case ex @ OperEx(TlaFunOper.tuple, elems @ _*) => // switch to cell theory val (stateAfterElems: SymbState, groundElems: Seq[TlaEx]) = rewriter.rewriteSeqUntilDone(state, elems) val cells = groundElems.map(stateAfterElems.arena.findCellByNameEx) - // Get the resulting type from the type finder. It may happen to be a sequence! - val resultT = rewriter.typeFinder.compute(state.ex, cells.map(_.cellType): _*) + // Get the resulting type from the type tag. It may be either a sequence or a tuple. + val resultT = CellT.fromTypeTag(ex.typeTag) resultT match { case tt @ TupleT(_) => createTuple(stateAfterElems, tt, cells) case st @ SeqT(_) => createSeq(stateAfterElems, st, cells) diff --git a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/aux/CherryPick.scala b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/aux/CherryPick.scala index 9878fb891c..0cbd58044b 100644 --- a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/aux/CherryPick.scala +++ b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/aux/CherryPick.scala @@ -7,6 +7,8 @@ import at.forsyte.apalache.tla.lir.UntypedPredefs._ import at.forsyte.apalache.tla.lir.values.{TlaIntSet, TlaNatSet} import at.forsyte.apalache.tla.lir.{NameEx, NullEx, TlaEx, ValEx} +import scala.collection.immutable.SortedMap + /** * An element picket that allows us: * @@ -205,46 +207,61 @@ class CherryPick(rewriter: SymbStateRewriter) { * Note that some record fields may have bogus values, since not all the records in the set * are required to have all the keys assigned. That is an unavoidable loophole in the record types. * - * @param cellType a cell type to assign to the picked cell. - * @param state a symbolic state - * @param oracle a variable that stores which element (by index) should be picked, can be unrestricted - * @param records a sequence of records of cellType + * @param cellTypeToIgnore a cell type to assign to the picked cell, this is not always the right type for records + * @param state a symbolic state + * @param oracle a variable that stores which element (by index) should be picked, can be unrestricted + * @param records a sequence of records of cellType * @return a new symbolic state with the expression holding a fresh cell that stores the picked element. */ - def pickRecord(cellType: CellT, state: SymbState, oracle: Oracle, records: Seq[ArenaCell], + def pickRecord(cellTypeToIgnore: CellT, state: SymbState, oracle: Oracle, records: Seq[ArenaCell], elseAssert: TlaEx): SymbState = { - // since we require all records to have exactly the same type, the code became much simpler - rewriter.solverContext.log("; CHERRY-PICK %s FROM [%s] {".format(cellType, records.map(_.toString).mkString(", "))) - val recordType = cellType.asInstanceOf[RecordT] + // the records do not always have the same type, but they do have compatible types + val commonRecordT = findCommonRecordType(records) + rewriter.solverContext + .log("; CHERRY-PICK %s FROM [%s] {".format(commonRecordT, records.map(_.toString).mkString(", "))) - def findKeyIndex(key: String): Int = - recordType.fields.keySet.toList.indexOf(key) + def findKeyIndex(recT: RecordT, key: String): Int = + recT.fields.keySet.toList.indexOf(key) var newState = state + def getKeyOrDefault(record: ArenaCell, key: String): ArenaCell = { + val thisRecT = record.cellType.asInstanceOf[RecordT] + if (thisRecT.fields.contains(key)) { + // this record has the key + val keyIndex = findKeyIndex(thisRecT, key) + newState.arena.getHas(record)(keyIndex) + } else { + // This record does not have the key, but it was mixed with other records and produced a more general type. + // Return a default value. As we are iterating over fields of commonRecordT, we will always find a value. + val valueT = commonRecordT.fields.get(key).get + newState = defaultValueFactory.makeUpValue(newState, valueT) + newState.asCell + } + } + def pickAtPos(key: String): ArenaCell = { - val keyIndex = findKeyIndex(key) - val slice = records.map(c => newState.arena.getHas(c)(keyIndex)) + val slice = records.map(c => getKeyOrDefault(c, key)) newState = pickByOracle(newState, oracle, slice, elseAssert) newState.asCell } // introduce a new record - newState = newState.setArena(newState.arena.appendCell(cellType)) + newState = newState.setArena(newState.arena.appendCell(commonRecordT)) val newRecord = newState.arena.topCell // pick the domain using the oracle. -// newState = pickSet(FinSetT(ConstT()), newState, oracle, records map (r => newState.arena.getDom(r))) - newState = pickRecordDomain(FinSetT(ConstT()), newState, oracle, records map (r => newState.arena.getDom(r))) + newState = pickRecordDomain(commonRecordT, FinSetT(ConstT()), newState, oracle, records, + records map (r => newState.arena.getDom(r))) val newDom = newState.asCell // pick the fields using the oracle - val fieldCells = recordType.fields.keySet.toSeq map pickAtPos + val fieldCells = commonRecordT.fields.keySet.toSeq map pickAtPos // and connect them to the record var newArena = newState.arena.setDom(newRecord, newDom) newArena = newArena.appendHasNoSmt(newRecord, fieldCells: _*) // The awesome property: we do not have to enforce equality of the field values, as this will be enforced by // the rule for the respective element r.key, as it will use the same oracle! - rewriter.solverContext.log(s"; } CHERRY-PICK $newRecord:$cellType") + rewriter.solverContext.log(s"; } CHERRY-PICK $newRecord:$commonRecordT") newState .setArena(newArena) @@ -258,53 +275,85 @@ class CherryPick(rewriter: SymbStateRewriter) { * This optimization prevents the model checker from blowing up in the number of record domains, e.g., in Raft. * * @param domType the goal type - * @param state a symbolic state - * @param oracle the oracle to use + * @param state a symbolic state + * @param oracle the oracle to use * @param domains the domains to pick from * @return a new cell that encodes a picked domain */ - private def pickRecordDomain(domType: CellT, state: SymbState, oracle: Oracle, domains: Seq[ArenaCell]): SymbState = { - // TODO: use elseAssert and Oracle.caseAssertions? + private def pickRecordDomain(commonRecordType: RecordT, domType: CellT, state: SymbState, oracle: Oracle, + records: Seq[ArenaCell], domains: Seq[ArenaCell]): SymbState = { // It often happens that all the domains are actually the same cell. Return this cell. val distinct = domains.distinct if (distinct.size == 1) { state.setRex(distinct.head.toNameEx) } else { - // consistency check: make sure that all the domains consist of exactly the same sets of keys - val keyCells = state.arena.getHas(domains.head) - for (dom <- domains.tail) { - val otherKeyCells = state.arena.getHas(dom) - assert(otherKeyCells.size == keyCells.size, - "inconsistent record domains of size %d and %d".format(keyCells.size, otherKeyCells.size)) - for ((k, o) <- keyCells.zip(otherKeyCells)) { - assert(k == o, s"inconsistent record domains: $k != $o") - } - } + val (newState, keyToCell) = findRecordKeys(state, commonRecordType) // introduce a new cell for the picked domain - var nextState = state.updateArena(_.appendCell(domType)) + var nextState = newState.updateArena(_.appendCell(domType)) val newDom = nextState.arena.topCell + // Add the cells for all potential keys. + // Importantly, they all come from strValueCache, so the same key produces the same cell. + val keyCells = keyToCell.values.toSeq nextState = nextState.updateArena(_.appendHas(newDom, keyCells: _*)) - // once we know that all the keys coincide, constrain membership with SMT + // constrain membership with SMT for ((dom, no) <- domains.zipWithIndex) { - def iffKey(keyCell: ArenaCell) = - tla.equiv(tla.in(keyCell.toNameEx, newDom.toNameEx), tla.in(keyCell.toNameEx, dom.toNameEx)) - - val keysMatch = tla.and(keyCells map iffKey: _*) - rewriter.solverContext.assertGroundExpr(tla.impl(oracle.whenEqualTo(nextState, no), keysMatch)) + val domainCells = nextState.arena.getHas(dom) + + for (keyCell <- keyCells) { + // Although we search over a list, the list size is usually small, e.g., up to 10 elements + if (domainCells.contains(keyCell)) { + // the key belongs to the new domain only if belongs to the domain that is pointed by the oracle + val iff = tla.equiv(tla.in(keyCell.toNameEx, newDom.toNameEx), tla.in(keyCell.toNameEx, dom.toNameEx)) + rewriter.solverContext.assertGroundExpr(tla.impl(oracle.whenEqualTo(nextState, no), iff)) + } else { + // The domain pointed by the oracle does not contain the key + val notInDom = tla.not(tla.in(keyCell.toNameEx, newDom.toNameEx)) + rewriter.solverContext.assertGroundExpr(tla.impl(oracle.whenEqualTo(nextState, no), notInDom)) + } + } } nextState.setRex(newDom.toNameEx) } } + private def findCommonRecordType(records: Seq[ArenaCell]): RecordT = { + var maxRecordType = records.head.cellType + for (rec <- records.tail) { + val recType = rec.cellType + recType.unify(maxRecordType) match { + case Some(commonType) => + maxRecordType = commonType + + case None => + throw new IllegalStateException(s"Found inconsistent records in a set: $maxRecordType and $recType") + } + } + maxRecordType.asInstanceOf[RecordT] + } + + // find the union of the keys for all records, if it exists + private def findRecordKeys(state: SymbState, recordType: RecordT): (SymbState, SortedMap[String, ArenaCell]) = { + val commonKeys = recordType.asInstanceOf[RecordT].fields.keySet + var keyToCell = SortedMap[String, ArenaCell]() + var nextState = state + for (key <- commonKeys) { + val (newArena, cell) = rewriter.strValueCache.getOrCreate(nextState.arena, key) + keyToCell = keyToCell + (key -> cell) + nextState = nextState.setArena(newArena) + } + + (nextState, keyToCell) + } + /** * Implements SE-PICK-SET. * * Note that some record fields may have bogus values, since not all the records in the set * are required to have all the keys assigned. That is an unavoidable loophole in the record types. * - * @param cellType a cell type to assign to the picked cell. - * @param state a symbolic state - * @param oracle a variable that stores which element (by index) should be picked, can be unrestricted + * @param cellType a cell type to assign to the picked cell. + * @param state a symbolic state + * @param oracle a variable that stores which element (by index) should be picked, can be unrestricted * @param memberSets a sequence of sets of cellType * @return a new symbolic state with the expression holding a fresh cell that stores the picked element. */ diff --git a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/aux/MapBase.scala b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/aux/MapBase.scala index e09c3aa387..1f75bf3351 100644 --- a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/aux/MapBase.scala +++ b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/aux/MapBase.scala @@ -8,6 +8,7 @@ import at.forsyte.apalache.tla.lir.convenience.tla import at.forsyte.apalache.tla.lir.oper.{TlaOper, TlaSetOper} import at.forsyte.apalache.tla.lir.{OperEx, TlaEx} import at.forsyte.apalache.tla.lir.UntypedPredefs._ +import at.forsyte.apalache.tla.pp.TlaInputError import scala.collection.mutable @@ -35,6 +36,9 @@ class MapBase(rewriter: SymbStateRewriter) { case FinSetT(elemType) => (setCell, elemType) + case InfSetT(elemType) => + throw new TlaInputError(s"Found a set map over an infinite set of $elemType. Not supported.") + case tp @ _ => throw new NotImplementedError("A set map over %s is not implemented".format(tp)) } } @@ -43,7 +47,7 @@ class MapBase(rewriter: SymbStateRewriter) { val elemsOfSets = setsAsCells.map(nextState.arena.getHas) val setLimits = elemsOfSets.map(_.size - 1) // find the types of the target expression and of the target set - val targetMapT = rewriter.typeFinder.computeRec(mapEx) + val targetMapT = CellT.fromTypeTag(mapEx.typeTag) val targetSetT = FinSetT(targetMapT) nextState = nextState.updateArena(_.appendCell(targetSetT)) diff --git a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/deprecated/CaseRule.scala b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/deprecated/CaseRule.scala index d824762ef0..366fe365bb 100644 --- a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/deprecated/CaseRule.scala +++ b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/deprecated/CaseRule.scala @@ -34,14 +34,14 @@ class CaseRule(rewriter: SymbStateRewriter) extends RewritingRule { val iteWaterfall = revGuardsAndActions.foldLeft(otherEx)(decorateWithIf) rewriter.rewriteUntilDone(state.setRex(iteWaterfall)) - case OperEx(TlaControlOper.caseNoOther, args @ _*) => + case ex @ OperEx(TlaControlOper.caseNoOther, args @ _*) => // first, rewrite all the arguments val (newState: SymbState, newArgs: Seq[TlaEx]) = rewriter.rewriteSeqUntilDone(state, args) val revGuardsAndActions = mkGuardsAndActions(newArgs) val cells = newArgs.map(newState.arena.findCellByNameEx) // get the expression type from the type finder (use the original expression as it could have been annotated!) - val resultType = rewriter.typeFinder.compute(state.ex, cells.map(_.cellType): _*) + val resultType = CellT.fromTypeTag(ex.typeTag) // place ASSERT(FALSE) instead of other val assertState = new TypedAssert(rewriter) .typedAssert(newState, resultType, tla.bool(false), "It may happen that no guard in CASE is applicable") diff --git a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/trex/ExecutionContext.scala b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/trex/ExecutionContext.scala index 78c33f1883..7ecc40c32a 100644 --- a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/trex/ExecutionContext.scala +++ b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/trex/ExecutionContext.scala @@ -3,7 +3,6 @@ package at.forsyte.apalache.tla.bmcmt.trex import at.forsyte.apalache.tla.bmcmt.SymbStateRewriter import at.forsyte.apalache.tla.bmcmt.rewriter.Recoverable import at.forsyte.apalache.tla.bmcmt.smt.SolverContext -import at.forsyte.apalache.tla.bmcmt.types.{CellT, TypeFinder} /** * A context that is used by TransitionExecutor. By default, a context is not thread-safe, @@ -17,7 +16,6 @@ trait ExecutionContext[SnapshotT] extends Recoverable[SnapshotT] { def rewriter: SymbStateRewriter - def typeFinder: TypeFinder[CellT] = rewriter.typeFinder def solver: SolverContext = rewriter.solverContext /** diff --git a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/trex/IncrementalExecutionContext.scala b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/trex/IncrementalExecutionContext.scala index 52a0bba18c..74ba8f58ae 100644 --- a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/trex/IncrementalExecutionContext.scala +++ b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/trex/IncrementalExecutionContext.scala @@ -19,7 +19,7 @@ class IncrementalExecutionContext(val rewriter: SymbStateRewriter) override def snapshot(): IncrementalExecutionContextSnapshot = { val level = rewriter.contextLevel rewriter.push() - IncrementalExecutionContextSnapshot(level, typeFinder.varTypes) + IncrementalExecutionContextSnapshot(level) } /** @@ -44,7 +44,6 @@ class IncrementalExecutionContext(val rewriter: SymbStateRewriter) } rewriter.pop(nPops) - rewriter.typeFinder.reset(snapshot.varTypes) } /** @@ -54,6 +53,5 @@ class IncrementalExecutionContext(val rewriter: SymbStateRewriter) override def dispose(): Unit = { // dispose the rewriter, which will, in turn, dispose the solver rewriter.dispose() - // nothing to dispose in the type finder } } diff --git a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/trex/IncrementalExecutionContextSnapshot.scala b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/trex/IncrementalExecutionContextSnapshot.scala index 14cbcc875b..5fe2c9cb01 100644 --- a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/trex/IncrementalExecutionContextSnapshot.scala +++ b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/trex/IncrementalExecutionContextSnapshot.scala @@ -9,10 +9,10 @@ import at.forsyte.apalache.tla.bmcmt.types.CellT * * @author Igor Konnov */ -class IncrementalExecutionContextSnapshot(var rewriterLevel: Int, val varTypes: Map[String, CellT]) {} +class IncrementalExecutionContextSnapshot(var rewriterLevel: Int) {} object IncrementalExecutionContextSnapshot { - def apply(rewriterDepth: Int, varTypes: Map[String, CellT]): IncrementalExecutionContextSnapshot = { - new IncrementalExecutionContextSnapshot(rewriterDepth, varTypes) + def apply(rewriterDepth: Int): IncrementalExecutionContextSnapshot = { + new IncrementalExecutionContextSnapshot(rewriterDepth) } } diff --git a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/trex/OfflineExecutionContext.scala b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/trex/OfflineExecutionContext.scala index 14aea92bfa..83c4f96e9d 100644 --- a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/trex/OfflineExecutionContext.scala +++ b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/trex/OfflineExecutionContext.scala @@ -22,7 +22,7 @@ class OfflineExecutionContext(var rewriter: SymbStateRewriter) val rs = rewriter.snapshot() val smtLog = rewriter.solverContext.asInstanceOf[RecordingSolverContext].extractLog() logger.debug("Offline snapshot has %d entries".format(smtLog.lengthRec)) - new OfflineExecutionContextSnapshot(rewriter.solverContext.config, rs, smtLog, typeFinder.varTypes) + new OfflineExecutionContextSnapshot(rewriter.solverContext.config, rs, smtLog) } /** @@ -41,12 +41,11 @@ class OfflineExecutionContext(var rewriter: SymbStateRewriter) override def recover(snapshot: OfflineExecutionContextSnapshot): Unit = { val solver = RecordingSolverContext.createZ3(Some(snapshot.smtLog), snapshot.solverConfig) // TODO: issue #105, remove references to SolverContext, so recovery becomes less of a hack - val newRewriter = new SymbStateRewriterImpl(solver, typeFinder, rewriter.exprGradeStore) + val newRewriter = new SymbStateRewriterImpl(solver, rewriter.exprGradeStore) newRewriter.formulaHintsStore = rewriter.formulaHintsStore newRewriter.config = rewriter.config newRewriter.recover(snapshot.rewriterSnapshot) newRewriter.solverContext = solver - newRewriter.typeFinder.reset(snapshot.varTypes) rewriter = newRewriter } diff --git a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/trex/OfflineExecutionContextSnapshot.scala b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/trex/OfflineExecutionContextSnapshot.scala index bcb52d739c..006af3a99d 100644 --- a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/trex/OfflineExecutionContextSnapshot.scala +++ b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/trex/OfflineExecutionContextSnapshot.scala @@ -2,7 +2,6 @@ package at.forsyte.apalache.tla.bmcmt.trex import at.forsyte.apalache.tla.bmcmt.rewriter.SymbStateRewriterSnapshot import at.forsyte.apalache.tla.bmcmt.smt.{SmtLog, SolverConfig} -import at.forsyte.apalache.tla.bmcmt.types.CellT /** * A snapshot when using a non-incremental SMT solver. @@ -10,4 +9,4 @@ import at.forsyte.apalache.tla.bmcmt.types.CellT * @author Igor Konnov */ class OfflineExecutionContextSnapshot(val solverConfig: SolverConfig, val rewriterSnapshot: SymbStateRewriterSnapshot, - val smtLog: SmtLog, val varTypes: Map[String, CellT]) {} + val smtLog: SmtLog) {} diff --git a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/trex/TransitionExecutorImpl.scala b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/trex/TransitionExecutorImpl.scala index ce2a3553ef..c6ba6ff9b7 100644 --- a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/trex/TransitionExecutorImpl.scala +++ b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/trex/TransitionExecutorImpl.scala @@ -57,7 +57,6 @@ class TransitionExecutorImpl[ExecCtxT](consts: Set[String], vars: Set[String], c throw new IllegalStateException(s"initializeConstants should be called only against the initial state") } logger.debug("Initializing CONSTANTS") - inferTypes(constInit) lastState = ctx.rewriter.rewriteUntilDone(lastState.setRex(constInit)) // check, whether all constants have been assigned val shiftedBinding = lastState.binding.shiftBinding(Set.empty) @@ -66,8 +65,6 @@ class TransitionExecutorImpl[ExecCtxT](consts: Set[String], vars: Set[String], c throw new IllegalStateException("CONSTANTS are not initialized: " + diff.mkString(", ")) } - shiftTypes(Set.empty) // treat constants as variables - lastState = lastState.setBinding(shiftedBinding) // update the execution stack in place, as we are dealing with the special case of constant initialization @@ -91,7 +88,6 @@ class TransitionExecutorImpl[ExecCtxT](consts: Set[String], vars: Set[String], c throw new IllegalStateException(s"prepareTransition is called for $transitionNo two times") } logger.debug(s"Step #${stepNo}, transition #${transitionNo}") - inferTypes(transitionEx) ctx.rewriter.solverContext.log( "; ------- STEP: %d, SMT LEVEL: %d TRANSITION: %d {" .format(stepNo, ctx.rewriter.contextLevel, transitionNo)) @@ -193,7 +189,6 @@ class TransitionExecutorImpl[ExecCtxT](consts: Set[String], vars: Set[String], c * though it may be an action expression. */ override def assertState(assertion: TlaEx): Unit = { - inferTypes(assertion) val nextState = ctx.rewriter.rewriteUntilDone(lastState.setRex(assertion)) ctx.rewriter.solverContext.assertGroundExpr(nextState.ex) lastState = nextState.setRex(lastState.ex) // propagate the arena and binding, but keep the old expression @@ -278,13 +273,7 @@ class TransitionExecutorImpl[ExecCtxT](consts: Set[String], vars: Set[String], c lastState = lastState .setBinding(lastState.binding.shiftBinding(consts)) .setRex(lastState.arena.cellTrue().toNameEx) - // save the types of the cells that are bound to the previous variables types, - // so the transition executor can process assertions over state variables of the whole execution - vars - .map(name => lastState.binding(name)) - .foreach(cell => ctx.typeFinder.extendWithCellType(cell)) // that is the result of this step - shiftTypes(consts) // importantly, clean the action-level caches, so the new variables are not mapped to the old variables ctx.rewriter.exprCache.disposeActionLevel() // clean the prepared transitions @@ -363,31 +352,4 @@ class TransitionExecutorImpl[ExecCtxT](consts: Set[String], vars: Set[String], c private def pushLastState(oracle: Oracle): Unit = { revStack = (lastState.binding.shiftBinding(consts), oracle) :: revStack } - - // infer the types and throw an exception if type inference has failed - private def inferTypes(expr: TlaEx): Unit = { -// logger.debug("Inferring types...") - ctx.typeFinder.inferAndSave(expr) - if (ctx.typeFinder.typeErrors.nonEmpty) { - throw new TypeInferenceException(ctx.typeFinder.typeErrors) - } - } - - /** - * Remove the non-primed variables (except provided constants) - * and rename the primed variables to their non-primed versions. - * After that, remove the type finder to contain the new types only. - */ - private def shiftTypes(constants: Set[String]): Unit = { - val types = ctx.typeFinder.varTypes - // keep the types of prime variables, cells, and constants - def keep(name: String): Boolean = { - name.endsWith("'") || ArenaCell.isValidName(name) || constants.contains(name) - } - val nextTypes = - types - .filter(p => keep(p._1)) - .map(p => (p._1.stripSuffix("'"), p._2)) - ctx.typeFinder.reset(nextTypes) - } } diff --git a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/types/AnnotationParser.scala b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/types/AnnotationParser.scala index 6f55ec1b06..f39bb48d86 100644 --- a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/types/AnnotationParser.scala +++ b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/types/AnnotationParser.scala @@ -15,6 +15,7 @@ import scala.collection.immutable.SortedMap * * @author Igor Konnov */ +@deprecated("This is a parser for the old type annotations. Do not use it!") object AnnotationParser { /** diff --git a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/types/TypeFinder.scala b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/types/TypeFinder.scala deleted file mode 100644 index e2892dae41..0000000000 --- a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/types/TypeFinder.scala +++ /dev/null @@ -1,93 +0,0 @@ -package at.forsyte.apalache.tla.bmcmt.types - -import at.forsyte.apalache.tla.bmcmt.{ArenaCell, TypeException} -import at.forsyte.apalache.tla.lir.TlaEx - -import scala.collection.immutable.{Map, SortedMap} - -/** - * A diagnostic message that is added in a list of errors. - * - * @param origin the expression that caused the type error - * @param explanation the explanation - */ -class TypeInferenceError(val origin: TlaEx, val explanation: String) - -/** - * A general interface to a type inference engine. Check the description in docs/types-api.md. - * - * @tparam T the base class of the type system - * @see CellT - * @author Igor Konnov - */ -trait TypeFinder[T] { - - /** - * Given a TLA+ expression, reconstruct the types and store them in an internal storage. - * If the expression is not well-typed, this method will not throw TypeInferenceError, - * but will collect a list of error that can be access with getTypeErrors. - * - * @param e a TLA+ expression. - * @return Some(type), if successful, and None otherwise - */ - def inferAndSave(e: TlaEx): Option[T] - - /** - * Retrieve the type errors from the latest call to inferAndSave. - * - * @return a list of type errors - */ - def typeErrors: Seq[TypeInferenceError] - - /** - * Given a TLA+ expression and the types of its arguments, compute the resulting type, if possible. - * This function uses the types that were pre-computed by inferAndSave. It should also work for arbitrary - * expressions, as soon as they can be unambiguously typed with the previously stored type information - * and the given arguments. - * - * @param e a TLA+ expression - * @param argTypes the types of the arguments. - * @return the resulting type, if it can be computed - * @throws TypeException , if the type cannot be computed. - */ - def compute(e: TlaEx, argTypes: T*): T - - /** - * Call compute recursively to compute the type of a given expression. This function is expensive, - * use it only when absolutely necessary. If the expression is referring to variables, inferAndSave should have - * been called before. - * - * @param ex a TLA+ expression - * @return the resulting type - */ - def computeRec(ex: TlaEx): CellT - - /** - * Get the types of the variables that are computed by inferAndSave. The method must return the types of - * the global variables (VARIABLE and CONSTANT) and it may return types of the bound variables. - * - * @return a mapping of names to types - */ - def varTypes: SortedMap[String, T] - - /** - * Restore variable types from a map. This method does not update type annotations. - * - * @param newVarTypes a mapping of names to types - */ - def varTypes_(newVarTypes: SortedMap[String, CellT]): Unit - - /** - * Record the cell name and its type. - * - * @param cell an arena cell - */ - def extendWithCellType(cell: ArenaCell): Unit - - /** - * Forget all computed types and introduce types for the variables. You can call inferAndSave after that. - * - * @param varTypes types of the global variables (VARIABLE and CONSTANT) - */ - def reset(varTypes: Map[String, T]): Unit -} diff --git a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/types/eager/TrivialTypeFinder.scala b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/types/eager/TrivialTypeFinder.scala deleted file mode 100644 index 450c5aff17..0000000000 --- a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/types/eager/TrivialTypeFinder.scala +++ /dev/null @@ -1,1044 +0,0 @@ -package at.forsyte.apalache.tla.bmcmt.types.eager - -import at.forsyte.apalache.tla.bmcmt.ArenaCell -import at.forsyte.apalache.tla.bmcmt.rewriter.Recoverable -import at.forsyte.apalache.tla.bmcmt.types._ -import at.forsyte.apalache.tla.lir._ -import at.forsyte.apalache.tla.lir.oper._ -import at.forsyte.apalache.tla.lir.transformations.TransformationListener -import at.forsyte.apalache.tla.lir.values._ - -import scala.collection.immutable.{Map, SortedMap} - -/** - *

An eager type finder that propagates types from the leaves to the root. - * As it can easily fail to find a type, the user has to write type annotations. - * In contrast, to our first type inference approach, this engine is not trying to be - * smart at all, and it is not doing any form of unification.

- * - *

This class assumes that some pre-processing has been made:

- * - *
    - *
  1. The definitions of all user-defined operators have been expanded (no recursive operators),
  2. - *
  3. All variable names are unique, including the bound variables.
  4. - *
- * - *

TrivialTypeFinder implements TransformationListener, so it propagates type annotations on the expressions - * after modifications.

- * - * @author Igor Konnov - */ -class TrivialTypeFinder - extends TypeFinder[CellT] with TransformationListener with Serializable with Recoverable[TrivialTypeSnapshot] { - private var _varTypes: SortedMap[String, CellT] = SortedMap() - private var _typeAnnotations: Map[UID, CellT] = Map() - private var _typeErrors: Seq[TypeInferenceError] = Seq() - - /** - * Get the types of the variables that are computed by inferAndSave. The method must return the types of - * the global variables (VARIABLE and CONSTANT) and it may return types of the bound variables. - * - * @return a mapping of names to types - */ - override def varTypes: SortedMap[String, CellT] = _varTypes - - /** - * Restore variable types from a map. - * - * @param newVarTypes a mapping of names to types - */ - override def varTypes_(newVarTypes: SortedMap[String, CellT]): Unit = { - _varTypes = newVarTypes - } - - /** - * Forget all computed types and introduce types for the variables. You can call inferAndSave after that. - * - * @param newVarTypes types of the global variables (VARIABLE and CONSTANT) - */ - override def reset(newVarTypes: Map[String, CellT]): Unit = { - _varTypes = SortedMap(newVarTypes.toSeq: _*) - _typeAnnotations = Map() - _typeErrors = Seq() - } - - /** - * Take a snapshot and return it - * - * @return the snapshot - */ - override def snapshot(): TrivialTypeSnapshot = { - new TrivialTypeSnapshot(_typeAnnotations, _varTypes) - } - - /** - * Recover a previously saved snapshot (not necessarily saved by this object). - * - * @param shot a snapshot - */ - override def recover(shot: TrivialTypeSnapshot): Unit = { - _typeAnnotations = shot.typeAnnotations - _varTypes = shot.varTypes - } - - /** - * Record the cell name and its type. - * - * @param cell an arena cell - */ - def extendWithCellType(cell: ArenaCell): Unit = { - _varTypes += cell.toString -> cell.cellType - } - - override def onTransformation(originalEx: TlaEx, newEx: TlaEx): Unit = { - _typeAnnotations.get(originalEx.ID) match { - // propagate type annotations - case Some(tp) => _typeAnnotations += newEx.ID -> tp - case _ => () - } - } - - override def onDeclTransformation(originalDecl: TlaDecl, newDecl: TlaDecl): Unit = { - // ignore transformations of declarations - } - - /** - * Given a TLA+ expression, reconstruct the types and store them in an internal storage. - * If the expression is not well-typed, diagnostic messages can be accessed with getTypeErrors. - * The main goal of this method is to assign types to the free and bound variables - * as we do not consider operators. (We allow nullary LET-IN operators though). - * - * @param expr a TLA+ expression. - * @return Some(type), if successful, and None otherwise - */ - override def inferAndSave(expr: TlaEx): Option[CellT] = { - // This class implements a very simple form of type inference from bottom to top. - // As soon as we cannot infer types, we complain that the type annotations are not good enough. - expr match { - // LET A == ... IN - // LET B == ... IN - // ... - // IN Z - case letIn @ LetInEx(body, defs @ _*) => - def inferDefResultType(d: TlaOperDecl): Unit = { - if (d.formalParams.nonEmpty) { - // This is a critical error in our code, which is not a type inference error - throw new IllegalStateException( - s"Found a non-constant LET-IN definition ${d.name}, which should have been expanded") - } else { - val resT = inferAndSave(d.body) - // Bind the let name to the computed type of the result. - // XXX: It is not a type of a variable, which may confuse the model checker. - _varTypes += d.name -> resT.getOrElse(UnknownT()) - } - } - - defs foreach inferDefResultType - inferAndSave(body) // body may use the types of the let definitions - - // x' = e - // x' \in S - case OperEx(BmcOper.assign, OperEx(TlaActionOper.prime, NameEx(varName)), rhs) => - def assignTypeAndReturnBool(assignedType: CellT): Option[CellT] = { - val primedVar = varName + "'" - if (_varTypes.contains(primedVar)) { - if (_varTypes(primedVar) != assignedType) { - error(expr, - "Assigning a type %s, while assigned type %s earlier" - .format(assignedType, _varTypes(primedVar))) - } - } else { - _varTypes = _varTypes + (primedVar -> assignedType) - } - Some(BoolT()) - } - - inferAndSave(rhs) match { - case Some(tp) => - assignTypeAndReturnBool(tp) - case tp @ None => - errorThenNone(rhs, "Expected a type, found: " + tp) - } - - // { x \in S: e } - case OperEx(TlaSetOper.filter, NameEx(x), set, pred) => - inferAndSave(set) match { - case Some(setT @ FinSetT(elemT)) => - assert(!_varTypes.contains(x)) - _varTypes = _varTypes + (x -> elemT) - val predT = inferAndSave(pred) - if (predT.contains(BoolT())) { - Some(setT) - } else { - errorThenNone(pred, "Expected a Boolean, found: " + predT) - } - - case tp @ _ => - _varTypes = _varTypes + (x -> UnknownT()) // otherwise, the type rewriter may throw an exception - errorThenNone(set, "Expected a finite set, found: " + tp) - } - - // {e : x \in S} - case OperEx(TlaSetOper.map, mapEx, varsAndSets @ _*) => - val names = varsAndSets.zipWithIndex.collect { case (NameEx(n), i) if i % 2 == 0 => n } - val sets = varsAndSets.zipWithIndex.collect { case (e, i) if i % 2 == 1 => e } - - def bind(name: String, set: TlaEx): Unit = { - inferAndSave(set) match { - case Some(setT @ FinSetT(elemT)) => - assert(!_varTypes.contains(name)) - _varTypes = _varTypes + (name -> elemT) - - case Some(PowSetT(setT @ FinSetT(_))) => - assert(!_varTypes.contains(name)) - _varTypes = _varTypes + (name -> setT) - - case tp @ _ => - _varTypes = _varTypes + (name -> UnknownT()) // otherwise, the type rewriter may throw an exception - errorThenNone(set, "Expected a finite set, found: " + tp) - } - } - - names.zip(sets) foreach (bind _).tupled - Some(FinSetT(inferAndSave(mapEx).getOrElse(UnknownT()))) - - // [x \in S |-> e] - case OperEx(op, funEx, varsAndSets @ _*) if op == TlaFunOper.funDef || op == TlaFunOper.recFunDef => - val names = varsAndSets.zipWithIndex.collect { case (NameEx(n), i) if i % 2 == 0 => n } - val sets = varsAndSets.zipWithIndex.collect { case (e, i) if i % 2 == 1 => e } - - def bind(name: String, set: TlaEx): Unit = { - inferAndSave(set) match { - case Some(setT @ FinSetT(elemT)) => - assert(!_varTypes.contains(name)) - _varTypes = _varTypes + (name -> elemT) - - case tp @ _ => - _varTypes = - _varTypes + (name -> UnknownT()) // otherwise, the type rewriter throws an exception 10 lines below - errorThenNone(set, "Expected a finite set, found: " + tp) - } - } - - names.zip(sets) foreach (bind _).tupled - val resT = inferAndSave(funEx).getOrElse(UnknownT()) - val domT = - if (names.length == 1) { - // a function of one argument - FinSetT(_varTypes(names.head)) - } else { - // a function of multiple arguments is a function from a Cartesian product to the result type - FinSetT(TupleT(names.map(_varTypes(_)))) - } - Some(FunT(domT, resT)) - - // exists, forall, or CHOOSE - case OperEx(op, NameEx(x), set, pred) - if op == TlaBoolOper.exists || op == TlaBoolOper.forall || op == TlaOper.chooseBounded => - // infer result by having computed the set type (below) - def inferResult(elemT: CellT) = { - assert(!_varTypes.contains(x)) - _varTypes = _varTypes + (x -> elemT) - val predT = inferAndSave(pred) - if (predT.contains(BoolT())) { - if (op == TlaOper.chooseBounded) { - Some(elemT) // CHOOSE - } else { - Some(BoolT()) // exists/forall - } - } else { - errorThenNone(pred, "Expected a Boolean, found: " + predT) - } - } - - // first, compute the set type and then the result - inferAndSave(set) match { - case Some(setT @ FinSetT(elemT)) => - inferResult(elemT) - - case Some(setT @ InfSetT(elemT)) if op == TlaBoolOper.exists => - // pass an infinite set, as it might be replaced with a constant, due to Skolemization - inferResult(elemT) - - case Some(_ @InfSetT(elemT)) if op == TlaOper.chooseBounded || op == TlaBoolOper.forall => - // complain right away - val name = if (op == TlaOper.chooseBounded) "CHOOSE" else "\\A" - errorThenNone(set, s"$name over an infinite set") - - case tp @ _ => - _varTypes = _varTypes + (x -> UnknownT()) // otherwise, the type rewriter may throw an exception - errorThenNone(set, "Expected a finite set, found: " + tp) - } - - // a type annotation for a recursive function call - case OperEx(BmcOper.withType, ex @ OperEx(TlaFunOper.recFunRef), annot) => - val annotT = AnnotationParser.fromTla(annot) - _typeAnnotations += (ex.ID -> annotT) - Some(annotT) - - // a type annotation - case OperEx(BmcOper.withType, ex, annot) => - val exT = inferAndSave(ex) - val annotT = AnnotationParser.fromTla(annot) - val unifier = unifyOption(Some(annotT), exT) - if (unifier.isDefined) { - // save the type annotation and return the type - _typeAnnotations += (ex.ID -> unifier.get) - unifier - } else { - val exTStr = if (exT.isDefined) exT.get.toString else None.toString - errorThenNone(annot, - s"No unifier for type $annotT and type $exTStr (from type annotation $annot and expression $ex)") - } - - case OperEx(TlaActionOper.prime, NameEx(name)) => - val primed = name + "'" - val result = _varTypes.get(primed) - if (result.isEmpty) { - errorThenNone(expr, s"Failed to find type of variable $primed") - } - result - - case ex @ OperEx(TlaActionOper.prime, arg) => - errorThenNone(ex, "Expected a name under ', found: " + arg) - - // other operators - case OperEx(_, args @ _*) => - val argTypes = args.map(inferAndSave) - if (argTypes.forall(_.isDefined)) { - Some(compute(expr, argTypes.map(_.get): _*)) - } else { - None - } - - case NameEx(name) => - var result = _varTypes.get(name) - if (result.isEmpty) { - error(expr, "Failed to find type of variable " + name) - } - result - - case ValEx(_) => - Some(compute(expr)) - - case _ => - None - } - } - - /** - * Retrieve the type errors from the latest call to inferAndSave. - * - * @return a list of type errors - */ - override def typeErrors: Seq[TypeInferenceError] = _typeErrors - - /** - * Call compute recursively to compute the type of a given expression. This function is expensive, - * use it only when absolutely necessary. - * - * TODO: remove this function and use inferAndSave instead? - * - * @param ex a TLA+ expression - * @return the resulting type - */ - override def computeRec(ex: TlaEx): CellT = ex match { - case OperEx(BmcOper.withType, annotated, _) => - // a pre-computed type annotation overrides the type info - assert(_typeAnnotations.contains(annotated.ID)) // otherwise, the engine is broken - _typeAnnotations(annotated.ID) - - case OperEx(TlaActionOper.prime, NameEx(_)) => - // do not recurse in prime, as the type of a primed variable should be computed directly - compute(ex) - - case LetInEx(body, _*) => - // compute the type of body, assuming that the types of the bound variables were computed by inferAndSave - computeRec(body) - - case OperEx(_, args @ _*) => - compute(ex, args map computeRec: _*) - - case _ => - compute(ex) - } - - /** - * Given a TLA+ expression and the types of its arguments, compute the resulting type, if possible. - * - * @param ex a TLA+ expression - * @param argTypes the types of the arguments. - * @return the resulting type, if it can be computed - * @throws TypeInferenceError if the type cannot be computed. - */ - override def compute(ex: TlaEx, argTypes: CellT*): CellT = { - if (_typeAnnotations.contains(ex.ID)) { - // this expression has been annotated with a type - _typeAnnotations(ex.ID) - } else { - // chain partial functions to handle different types of operators with different functions - val handlers = - (computeValues - :: computeBasicOps(argTypes) - :: computeBoolOps(argTypes) - :: computeIntOps(argTypes) - :: computeControlOps(argTypes) - :: computeSetCtors(argTypes) - :: computeFunCtors(argTypes) - :: computeSetOps(argTypes) - :: computeFunOps(argTypes) - :: computeFunApp(argTypes) - :: computeFiniteSetOps(argTypes) - :: computeSeqOps(argTypes) - :: computeMiscOps(argTypes) - :: notImplemented :: Nil) reduceLeft (_ orElse _) - handlers(ex) - } - } - - private def computeValues: PartialFunction[TlaEx, CellT] = { - case ValEx(TlaInt(_)) => - IntT() - - case ValEx(TlaBool(_)) => - BoolT() - - case ValEx(TlaStr(_)) => - ConstT() - - case ValEx(TlaNatSet) => - InfSetT(IntT()) - - case ValEx(TlaIntSet) => - InfSetT(IntT()) - } - - private def computeBasicOps(argTypes: Seq[CellT]): PartialFunction[TlaEx, CellT] = { - case ne @ NameEx(name) => - _varTypes - .get(name) - .orElse(Some(error(ne, "No type assigned to " + name))) - .get - - case app @ OperEx(TlaOper.apply, NameEx(opName)) => - _varTypes - .get(opName.toString) // String.toString ?? - .orElse(Some(error(app, "No type assigned to " + opName))) - .get - - case OperEx(TlaOper.apply, opName, _*) => - throw new IllegalStateException(s"Unexpected operator call to $opName. Report a bug!") - - case ne @ OperEx(TlaActionOper.prime, NameEx(name)) => - val primed = name + "'" - _varTypes - .get(primed) - .orElse(Some(error(ne, "No type assigned to " + primed))) - .get - - case ex @ OperEx(op, _, _) if op == TlaOper.eq || op == TlaOper.ne => - expectEqualTypes(ex, argTypes: _*) - BoolT() - - case ex @ OperEx(op @ TlaOper.chooseBounded, x, set, pred) => - val xType = argTypes.head - val setType = argTypes.tail.head - val predType = argTypes.tail.tail.head - setType match { - case FinSetT(elemT) => - expectType(elemT, x, xType) - expectType(BoolT(), pred, predType) - elemT - - case _ => - errorUnexpected(ex, op.name, argTypes) - } - - case ex @ OperEx(op @ TlaOper.chooseUnbounded, x, pred) => - val xType = argTypes.head - val predType = argTypes.tail.head - expectType(BoolT(), pred, predType) - xType - - case ex @ OperEx(op @ TlaOper.chooseIdiom, _) => - argTypes match { - case Seq(FinSetT(elemT)) => - elemT - - case _ => - errorUnexpected(ex, op.name, argTypes) - } - - case ex @ OperEx(op @ TlaOper.label, _, _, _*) => - val decoratedExprType = argTypes.head - val nameAndArgTypes = argTypes.tail - nameAndArgTypes.foreach(expectType(ConstT(), ex, _)) - decoratedExprType - } - - private def computeSetCtors(argTypes: Seq[CellT]): PartialFunction[TlaEx, CellT] = { - case ex @ OperEx(TlaSetOper.enumSet, _*) => - if (argTypes.isEmpty) { - // This case typically causes problems, as any operation with - // a concrete type would fail. One has to use a type annotation. - FinSetT(UnknownT()) - } else { - expectEqualTypes(ex, argTypes: _*) - FinSetT(argTypes.head) - } - - case ex @ OperEx(op @ TlaSetOper.funSet, _, _) => - argTypes match { - case Seq(FinSetT(argT), FinSetT(resT)) => - // FinT expects the types of the domain and the result (not of the co-domain!) - FinSetT(FunT(FinSetT(argT), resT)) - - case Seq(FinSetT(argT), InfSetT(resT)) => - // a result from an infinite domain is ok, as soon as we are not unfolding this domain - FinSetT(FunT(FinSetT(argT), resT)) - - case _ => errorUnexpected(ex, op.name, argTypes) - } - - case ex @ OperEx(TlaSetOper.recSet, args @ _*) => - assert(argTypes.nonEmpty) - val fieldNames = deinterleave(args, 0, 2) - .collect { case ValEx(TlaStr(a)) => a } - val _, fieldTypes = deinterleave(argTypes, 1, 2) - val elemTypes = argTypes.collect { case FinSetT(t) => t } - if (elemTypes.size < fieldTypes.size) { - error(ex, "Only finite sets of records are supported in [a: A, ..., z: Z]") - } - assert(fieldNames.length == fieldTypes.length) - FinSetT(RecordT(SortedMap(fieldNames.zip(elemTypes): _*))) - - case ex @ OperEx(op @ TlaSetOper.powerset, _) => - argTypes match { - case Seq(FinSetT(elemT)) => - FinSetT(FinSetT(elemT)) - - // what about SUBSET [S -> T]? - - case _ => errorUnexpected(ex, op.name, argTypes) - } - - case ex @ OperEx(TlaSetOper.times, _*) => - assert(argTypes.nonEmpty) - val elemTypes = argTypes.collect({ case FinSetT(t) => t }) // using partial functions - if (elemTypes.size < argTypes.size) { - error(ex, "Only finite sets are supported in the cross product A \\X B") - } - FinSetT(TupleT(elemTypes)) - - case ValEx(TlaBoolSet) => // BOOLEAN - FinSetT(BoolT()) - } - - private def computeSetOps(argTypes: Seq[CellT]): PartialFunction[TlaEx, CellT] = { - case ex @ OperEx(op @ TlaSetOper.union, _) => - argTypes match { - case Seq(FinSetT(FinSetT(elemT))) => - FinSetT(elemT) - - case _ => errorUnexpected(ex, op.name, argTypes) - } - - case ex @ OperEx(op @ TlaSetOper.filter, _, _, _) => - argTypes match { - case Seq(_, FinSetT(elemT), BoolT()) => - FinSetT(elemT) - - case Seq(_, PowSetT(elemT), BoolT()) => - FinSetT(elemT) // powersets are expanded - - // what about {f \in [S -> T] : ... }? - // what about {f \in [a: S, B: T] |-> ... }? - - case _ => errorUnexpected(ex, op.name, argTypes) - } - - case ex @ OperEx(op @ TlaSetOper.map, _*) => - var varType: CellT = UnknownT() - for ((tp, index) <- argTypes.tail.zipWithIndex) { - if (index % 2 == 0) { - varType = tp // save the type of the bound variable - } - if (index % 2 == 1) { - tp match { - case FinSetT(et) => - if (et != varType) { - error(ex, "Expected Set(%s) at position %d, found: %s".format(varType, index / 2, tp)) - } - - // what about {f \in [S -> T] |-> ... }? - // what about {f \in [a: S, B: T] |-> ... }? - case _ => errorUnexpected(ex, op.name, argTypes) - } - } - } - FinSetT(argTypes.head) - - case ex @ OperEx(op, _, _) if op == TlaSetOper.in || op == TlaSetOper.notin => - argTypes match { - case Seq(memT, FinSetT(elemT)) => - expectEqualTypes(ex, memT, elemT) - BoolT() - - case Seq(memT, InfSetT(elemT)) => - expectEqualTypes(ex, memT, elemT) - BoolT() - - // what about f \in [S -> T]? - // what about f \in [a: S, B: T]? - - case _ => errorUnexpected(ex, op.name, argTypes) - } - - case ex @ OperEx(op, _, _) - if op == TlaSetOper.subsetProper || op == TlaSetOper.subseteq - || op == TlaSetOper.supsetProper || op == TlaSetOper.supseteq => - argTypes match { - case Seq(FinSetT(leftT), FinSetT(rightT)) => - expectEqualTypes(ex, leftT, rightT) - BoolT() - - // what about f \in [S -> T]? - // what about f \in [a: S, B: T]? - case _ => errorUnexpected(ex, op.name, argTypes) - } - - case ex @ OperEx(op, _, _) if op == TlaSetOper.cup || op == TlaSetOper.cap || op == TlaSetOper.setminus => - argTypes match { - case Seq(FinSetT(leftT), FinSetT(rightT)) => - expectEqualTypes(ex, leftT, rightT) - FinSetT(leftT) - - case _ => errorUnexpected(ex, op.name, argTypes) - } - } - - private def computeFunCtors(argTypes: Seq[CellT]): PartialFunction[TlaEx, CellT] = { - case ex @ OperEx(TlaFunOper.tuple) => - SeqT(UnknownT()) - - case ex @ OperEx(op @ TlaFunOper.tuple, _*) => - TupleT(argTypes) - - case ex @ OperEx(op @ TlaFunOper.enum, args @ _*) => - assert(argTypes.nonEmpty) - val fieldNames = deinterleave(args, 0, 2) collect { case ValEx(TlaStr(a)) => a } - val namesTypes = deinterleave(argTypes, 0, 2) collect { case ConstT() => true } - - if (namesTypes.size != fieldNames.size) { - errorUnexpected(ex, op.name, argTypes) - } - val fieldTypes = deinterleave(argTypes, 1, 2) - assert(fieldNames.length == fieldTypes.length) - RecordT(SortedMap(fieldNames.zip(fieldTypes): _*)) - } - - private def computeFunApp(argTypes: Seq[CellT]): PartialFunction[TlaEx, CellT] = { - case ex @ OperEx(op @ TlaFunOper.app, fun, arg) => - val funType = argTypes.head - val argType = argTypes.tail.head - funType match { - case FunT(FinSetT(funArgT), funResT) if funArgT == argType => - funResT - - case SeqT(resT) if argType == IntT() => - resT - - case TupleT(elemTypes) if argType == IntT() => - // try to extract an integer from the expression - arg match { - case ValEx(TlaInt(i)) => - if (i >= 1 && i <= elemTypes.length) { - elemTypes(i.toInt - 1) // the argument is within a small range, so toInt should work - } else { - error(ex, "The tuple argument is out of range: " + i) - } - - case _ => error(ex, "Expected an integer constant as the tuple argument, found: " + arg) - } - - case RecordT(fields) if argType == ConstT() => - // try to extract a string from the expression - arg match { - case ValEx(TlaStr(s)) => - if (fields.contains(s)) { - fields(s) - } else { - error(ex, "Unexpected record field name: " + s) - } - - case _ => error(ex, "Expected a string constant as the tuple argument, found: " + arg) - } - - case _ => - errorUnexpected(ex, op.name, argTypes) - } - } - - private def computeFunOps(argTypes: Seq[CellT]): PartialFunction[TlaEx, CellT] = { - case ex @ OperEx(op, e, bindings @ _*) if op == TlaFunOper.funDef || op == TlaFunOper.recFunDef => - val resType = argTypes.head - val setTypes = deinterleave(argTypes.tail, 1, 2) - val varTypes = deinterleave(argTypes.tail, 0, 2) - if (varTypes.length != setTypes.length) { - errorUnexpected(ex, op.name, argTypes) - } else { - val elemTypes = setTypes.collect { case FinSetT(et) => et } - if (elemTypes.length != setTypes.length) { - // wrong types were passed - errorUnexpected(ex, op.name, argTypes) - } - if (setTypes.length == 1) { - // a single-argument function - FunT(setTypes.head, resType) - } else { - // a multi-argument function, which means it receives a Cartesian product - FunT(FinSetT(TupleT(elemTypes)), resType) - } - } - - case ex @ OperEx(TlaFunOper.recFunRef) => - // no annotation met, produce an error - error(ex, - "Reference to a recursive function needs type annotation, see:" + - " https://apalache.informal.systems/docs/apalache/principles.html#rec-fun") - - case ex @ OperEx(op @ TlaFunOper.except, e, bindings @ _*) => - val funType = argTypes.head - // In principle, we could just return the function itself. - // But we also check the argument types to be on a well-typed side. - val indices = deinterleave(bindings, 0, 2) // the expressions - val indexTypes = deinterleave(argTypes.tail, 0, 2) - val valueTypes = deinterleave(argTypes.tail, 1, 2) - funType match { - case FunT(_, _) => - val (argT, resT) = - funType match { - // an argument to EXCEPT is always wrapped into a tuple - case FunT(FinSetT(elemT), rt) => (TupleT(Seq(elemT)), rt) - case _ => error(ex, "Expected a function type, found: " + funType) - } - for (idx <- indexTypes) { - if (idx != argT) { - error(ex, "Expected an index of type TupleT(%s), found: %s".format(argT, idx)) - } - } - for (valT <- valueTypes) { - if (valT != resT) { - error(ex, "Expected a result of type %s, found: %s".format(resT, valT)) - } - } - - case rec @ RecordT(_) => - for (idx <- indexTypes) { - if (idx != TupleT(Seq(ConstT()))) { - error(ex, "Expected an index of type %s, found: %s".format(ConstT(), idx)) - } - } - for ((index, valT) <- indices.zip(valueTypes)) { - index match { - case OperEx(TlaFunOper.tuple, ValEx(TlaStr(key))) => - if (valT != rec.fields(key)) { - error(ex, "Expected an index of type TupleT(%s), found: %s".format(rec.fields(key), valT)) - } - - case _ => - error(ex, s"Expected a record key, found: $index") - } - - } - - case tup @ TupleT(Seq(argTypes @ _*)) => - for (idx <- indexTypes) { - if (idx != TupleT(Seq(IntT()))) { - error(ex, "Expected an index of type TupleT(%s), found: %s".format(IntT(), idx)) - } - } - for ((argT, valT) <- argTypes.zip(valueTypes)) { - if (argT != valT) { - error(ex, "Expected a value of type %s, found: %s".format(argT, valT)) - } - } - - case _ => - error(ex, "Expected a function, a record, or a tuple") - } - - funType - - case ex @ OperEx(TlaFunOper.domain, fun) => - argTypes.head match { - case FunT(domT, _) => domT - case TupleT(_) => FinSetT(IntT()) - case RecordT(_) => FinSetT(ConstT()) - case SeqT(_) => FinSetT(IntT()) - case _ => error(ex, "Unexpected type of DOMAIN argument: " + ex) - } - } - - private def computeIntOps(argTypes: Seq[CellT]): PartialFunction[TlaEx, CellT] = { - case ex @ OperEx(op, _) if op == TlaArithOper.uminus => - assert(argTypes.length == 1) - expectType(IntT(), ex, argTypes.head) - IntT() - - case ex @ OperEx(TlaArithOper.dotdot, _, _) => - assert(argTypes.length == 2) - argTypes.foreach(expectType(IntT(), ex, _)) - FinSetT(IntT()) - - case ex @ OperEx(op, _, _) - if op == TlaArithOper.plus || op == TlaArithOper.minus - || op == TlaArithOper.mult || op == TlaArithOper.div || op == TlaArithOper.mod || op == TlaArithOper.exp => - assert(argTypes.length == 2) - argTypes.foreach(expectType(IntT(), ex, _)) - IntT() - - case ex @ OperEx(op, _, _) - if op == TlaArithOper.lt || op == TlaArithOper.gt || op == TlaArithOper.le || op == TlaArithOper.ge => - assert(argTypes.length == 2) - argTypes.foreach(expectType(IntT(), ex, _)) - BoolT() - - case ex @ OperEx(op, _*) if op == TlaArithOper.sum || op == TlaArithOper.prod => - argTypes.foreach(expectType(IntT(), ex, _)) - IntT() - } - - private def computeBoolOps(argTypes: Seq[CellT]): PartialFunction[TlaEx, CellT] = { - case ex @ OperEx(TlaBoolOper.not, _) => - assert(argTypes.length == 1) - expectType(BoolT(), ex, argTypes.head) - BoolT() - - case ex @ OperEx(op, _, _) if op == TlaBoolOper.implies || op == TlaBoolOper.equiv => - assert(argTypes.length == 2) - argTypes.foreach(expectType(BoolT(), ex, _)) - BoolT() - - case ex @ OperEx(op, _*) if op == TlaBoolOper.and || op == TlaBoolOper.or => - argTypes.foreach(expectType(BoolT(), ex, _)) - BoolT() - - case ex @ OperEx(op, x, set, pred) if op == TlaBoolOper.forall || op == TlaBoolOper.exists => - val xType = argTypes.head - val setType = argTypes.tail.head - val predType = argTypes.tail.tail.head - expectType(BoolT(), pred, predType) - setType match { - case FinSetT(elemT) => - expectType(elemT, x, xType) - - case InfSetT(elemT) => - expectType(elemT, x, xType) - - case _ => - errorUnexpected(set, op.name, argTypes) - } - BoolT() - } - - private def computeControlOps(argTypes: Seq[CellT]): PartialFunction[TlaEx, CellT] = { - case ex @ OperEx(TlaControlOper.ifThenElse, predEx, thenEx, elseEx) => - assert(argTypes.length == 3) - expectType(BoolT(), predEx, argTypes.head) - val leftType = argTypes.tail.head - expectEqualTypes(ex, argTypes.tail: _*) - leftType - - case ex @ OperEx(TlaControlOper.caseNoOther, _*) => - val guards = argTypes.zipWithIndex.collect { case (a, i) if i % 2 == 0 => a } - val actions = argTypes.zipWithIndex.collect { case (a, i) if i % 2 == 1 => a } - guards.foreach(expectType(BoolT(), ex, _)) - expectEqualTypes(ex, actions: _*) - actions.head - - case ex @ OperEx(TlaControlOper.caseWithOther, _*) => - val guards = argTypes.zipWithIndex.collect { case (a, i) if i % 2 == 1 => a } - val actions = argTypes.zipWithIndex.collect { case (a, i) if i % 2 == 0 => a } - guards.foreach(expectType(BoolT(), ex, _)) - expectEqualTypes(ex, actions: _*) - actions.head - - case ex @ LetInEx(_, _*) => - // Can we really type-check anything here? We would need to analyze the let bindings. - argTypes.head - } - - private def computeFiniteSetOps(argTypes: Seq[CellT]): PartialFunction[TlaEx, CellT] = { - case ex @ OperEx(op, _) if op == TlaFiniteSetOper.isFiniteSet || op == TlaFiniteSetOper.cardinality => - assert(argTypes.length == 1) - argTypes.head match { - case FinSetT(_) => - if (op == TlaFiniteSetOper.isFiniteSet) - BoolT() - else - IntT() - - case _ => - errorUnexpected(ex, op.name, argTypes) - } - } - - private def computeSeqOps(argTypes: Seq[CellT]): PartialFunction[TlaEx, CellT] = { - case ex @ OperEx(op, _) if op == TlaSeqOper.head || op == TlaSeqOper.tail || op == TlaSeqOper.len => - assert(argTypes.length == 1) - argTypes.head match { - case SeqT(elemT) => - if (op == TlaSeqOper.head) - elemT - else if (op == TlaSeqOper.tail) - SeqT(elemT) - else IntT() // len - - case _ => - errorUnexpected(ex, op.name, argTypes) - } - - case ex @ OperEx(op @ TlaSeqOper.append, _, argEx) => - assert(argTypes.length == 2) - argTypes.head match { - case SeqT(elemT) => - expectType(elemT, argEx, argTypes.tail.head) - SeqT(elemT) - - case _ => - errorUnexpected(ex, op.name, argTypes) - } - - case ex @ OperEx(op @ TlaSeqOper.concat, lex, rex) => - assert(argTypes.length == 2) - argTypes.head match { - case SeqT(elemT) => - expectType(SeqT(elemT), rex, argTypes.tail.head) - SeqT(elemT) - - case _ => - errorUnexpected(ex, op.name, argTypes) - } - - case ex @ OperEx(op @ TlaSeqOper.subseq, seq, start, end) => - assert(argTypes.length == 3) - argTypes.head match { - case SeqT(elemT) => - expectType(IntT(), start, argTypes.tail.head) - expectType(IntT(), end, argTypes.tail.tail.head) - SeqT(elemT) - - case _ => - errorUnexpected(ex, op.name, argTypes) - } - - case ex @ OperEx(op @ TlaSeqOper.selectseq, seq, pred) => - // pred should be a second-level operator. How would we implement it here? - throw new NotImplementedError("Type construction for Sequence.selectseq cannot be implemented.") - } - - private def computeMiscOps(argTypes: Seq[CellT]): PartialFunction[TlaEx, CellT] = { - case ex @ OperEx(BmcOper.`skolem`, _) => - BoolT() - - case ex @ OperEx(BmcOper.`constCard`, _) => - BoolT() - - case ex @ OperEx(BmcOper.expand, _) => - argTypes.head - - case ex @ OperEx(TlaOper.label, args @ _*) => - for ((a, t) <- args.tail.zip(argTypes.tail)) expectType(ConstT(), a, t) - argTypes.head - - case ex @ OperEx(TlcOper.assert, expr, msg) => - val exprType = argTypes.head - val msgType = argTypes.tail.head - expectType(BoolT(), expr, exprType) - expectType(ConstT(), msg, msgType) - BoolT() - - case ex @ OperEx(TlcOper.print, _, msg) => - // an expression can be of any type - val msgType = argTypes.tail.head - expectType(ConstT(), msg, msgType) - BoolT() - - case ex @ OperEx(TlcOper.printT, msg) => - val msgType = argTypes.head - expectType(ConstT(), msg, msgType) - BoolT() - - case ex @ OperEx(TlcOper.colonGreater, _, _) => - TupleT(argTypes) // a :> b is simply <> in our type system - - case ex @ OperEx(TlcOper.atat, _, _) => - argTypes.head match { - case funT @ FunT(FinSetT(argT), resT) => - argTypes.tail.head match { - case TupleT(Seq(at, rt)) => - expectEqualTypes(ex, argT, at) - expectEqualTypes(ex, resT, rt) - funT - - case tt @ _ => - expectType(TupleT(Seq(argT, resT)), ex, tt) - funT - } - - case _ => - errorUnexpected(ex, TlcOper.atat.name, argTypes) - } - - case ex @ OperEx(BmcOper.withType, _*) => - throw new IllegalStateException("The type annotation must have been saved by inferAndSave: " + ex) - } - - private def expectType(expectedType: CellT, ex: TlaEx, exType: CellT): Unit = { - if (exType != expectedType) { - error(ex, "Expected type %s, found %s".format(expectedType, exType)) - } - } - - private def expectEqualTypes(ex: TlaEx, types: CellT*): Unit = { - if (types.nonEmpty) { - val firstType = types.head - - if (types.tail.exists(_ != firstType)) { - error(ex, "Expected equal types: %s".format(types.mkString(" and "))) - } - } - } - - private def errorUnexpected(ex: TlaEx, opname: String, argTypes: Seq[CellT]): CellT = { - error(ex, "Unexpected types for %s: %s".format(opname, argTypes.mkString(", "))) - } - - private def error(ex: TlaEx, message: String): CellT = { - _typeErrors :+= new TypeInferenceError(ex, message) - UnknownT() - } - - private def errorThenNone(ex: TlaEx, message: String): Option[CellT] = { - error(ex, message) - None - } - - private def notImplemented: PartialFunction[TlaEx, CellT] = { case ex => - throw new NotImplementedError("Type construction for %s is not implemented. Report a bug!".format(ex)) - } - - /** - * Get a subsequence of elements whose indices satisfy the predicate: index % base == group. - * - * @param s sequence - * @param group the group number (from 0 to base - 1) - * @param base the divider to use in the modulo operation - * @tparam T element type - * @return the subsequence of s s.t. index % base == group - */ - private def deinterleave[T](s: Seq[T], group: Int, base: Int): Seq[T] = { - s.zipWithIndex.collect { case (e, i) if i % base == group => e } - } -} diff --git a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/types/eager/TrivialTypeSnapshot.scala b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/types/eager/TrivialTypeSnapshot.scala deleted file mode 100644 index 3a0ed04dd4..0000000000 --- a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/types/eager/TrivialTypeSnapshot.scala +++ /dev/null @@ -1,14 +0,0 @@ -package at.forsyte.apalache.tla.bmcmt.types.eager - -import at.forsyte.apalache.tla.bmcmt.types.CellT -import at.forsyte.apalache.tla.lir.UID - -import scala.collection.immutable.{Map, SortedMap} - -/** - * A snapshot of TrivialTypeFinder that can be recovered into a new TrivialTypeFinder. - * All intermediate context are squashed into a single context. - * - * @author Igor Konnov - */ -class TrivialTypeSnapshot(val typeAnnotations: Map[UID, CellT], val varTypes: SortedMap[String, CellT]) {} diff --git a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/types/package.scala b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/types/package.scala index 2bd09de176..d5025c5872 100644 --- a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/types/package.scala +++ b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/types/package.scala @@ -1,7 +1,9 @@ package at.forsyte.apalache.tla.bmcmt -import at.forsyte.apalache.tla.lir.NullEx -import at.forsyte.apalache.tla.lir.io.UTFPrinter +import at.forsyte.apalache.tla.lir.{ + BoolT1, ConstT1, FunT1, IntT1, NullEx, OperT1, RealT1, RecT1, SeqT1, SetT1, SparseTupT1, StrT1, TlaType1, TupT1, + TypeTag, TypingException, VarT1 +} import scala.collection.immutable.SortedMap @@ -33,6 +35,13 @@ package object types { */ val signature: String + /** + * Convert the cell type into TlaType1. Note that + * + * @return + */ + def toTlaType1: TlaType1 + /** * Compute a unifier of two types. * @@ -138,6 +147,52 @@ package object types { } } + object CellT { + + /** + * Convert a TlaType1 to a cell type. + */ + def fromType1(tt: TlaType1): CellT = { + tt match { + case IntT1() => IntT() + case StrT1() => ConstT() + case BoolT1() => BoolT() + case ConstT1(_) => ConstT() // this should change in https://github.com/informalsystems/apalache/issues/570 + case RealT1() => + throw new TypingException("Unsupported type RealT1") + + case VarT1(_) => + // type variables should have been resolved by operator inlining and type checking + throw new TypingException("Unexpected type VarT1") + + case SetT1(elem) => FinSetT(fromType1(elem)) + case SeqT1(elem) => SeqT(fromType1(elem)) + case FunT1(arg, res) => FunT(FinSetT(fromType1(arg)), fromType1(res)) + case TupT1(elems @ _*) => TupleT(elems.map(fromType1)) + case RecT1(fieldTypes) => RecordT(fieldTypes.mapValues(fromType1)) + + case SparseTupT1(_) => + // sparse tuple can only appear in operator arguments, which must have been inlined + throw new TypingException("Unexpected type SparseTupT1") + + case OperT1(_, _) => + // all operators are inlined + throw new TypingException("Unexpected operator type OperT1") + } + } + + /** + * Convert a type tag to a cell type + * + * @param typeTag a type tag + * @return the corresponding cell type, if the tag has type Typed(_: TlaType1); otherwise, throw an exception + */ + def fromTypeTag(typeTag: TypeTag): CellT = { + fromType1(TlaType1.fromTypeTag(typeTag)) + } + + } + // Jure @ 13.09.18: Suggestion: Replace all case class X( [no args] ) with object X ? // // Igor @ 20.12.18: This is called preliminary optimization. I would imagine that the scala @@ -161,6 +216,10 @@ package object types { override val signature: String = "u" override val toString: String = "Unknown" + + override def toTlaType1: TlaType1 = { + ConstT1("UNKNOWN") + } } /** @@ -172,57 +231,46 @@ package object types { override val signature: String = "E" override val toString: String = "FailPred" + + override def toTlaType1: TlaType1 = { + ConstT1("FAILPRED") + } } /** * A cell constant, that is, just a name that expresses string constants in TLA+. */ sealed case class ConstT() extends CellT with Serializable { - - /** - * Produce a short signature that uniquely describes the type (up to unification), - * similar to Java's signature mangling. If one type can be unified to another, - * e.g., records, they have the same signature. - * - * @return a short signature that uniquely characterizes this type up to unification - */ override val signature: String = "str" override val toString: String = "Const" + + override def toTlaType1: TlaType1 = { + // in the new type system, we have the string type for such constants + StrT1() + } } /** * A Boolean cell type. */ sealed case class BoolT() extends CellT with Serializable { - - /** - * Produce a short signature that uniquely describes the type (up to unification), - * similar to Java's signature mangling. If one type can be unified to another, - * e.g., records, they have the same signature. - * - * @return a short signature that uniquely characterizes this type up to unification - */ override val signature: String = "b" override val toString: String = "Bool" + + override def toTlaType1: TlaType1 = BoolT1() } /** * An integer cell type. */ sealed case class IntT() extends CellT with Serializable { - - /** - * Produce a short signature that uniquely describes the type (up to unification), - * similar to Java's signature mangling. If one type can be unified to another, - * e.g., records, they have the same signature. - * - * @return a short signature that uniquely characterizes this type up to unification - */ override val signature: String = "i" override val toString: String = "Int" + + override def toTlaType1: TlaType1 = IntT1() } /** @@ -231,17 +279,11 @@ package object types { * @param elemType the elements type */ sealed case class FinSetT(elemType: CellT) extends CellT with Serializable { - - /** - * Produce a short signature that uniquely describes the type (up to unification), - * similar to Java's signature mangling. If one type can be unified to another, - * e.g., records, they have the same signature. - * - * @return a short signature that uniquely characterizes this type up to unification - */ override val signature: String = s"S${elemType.signature}" override val toString: String = s"FinSet[$elemType]" + + override def toTlaType1: TlaType1 = SetT1(elemType.toTlaType1) } /** @@ -251,17 +293,11 @@ package object types { * @param elemType the elements type */ sealed case class InfSetT(elemType: CellT) extends CellT with Serializable { - - /** - * Produce a short signature that uniquely describes the type (up to unification), - * similar to Java's signature mangling. If one type can be unified to another, - * e.g., records, they have the same signature. - * - * @return a short signature that uniquely characterizes this type up to unification - */ override val signature: String = s"Z${elemType.signature}" override val toString: String = s"InfSet[$elemType]" + + override def toTlaType1: TlaType1 = SetT1(elemType.toTlaType1) } /** @@ -271,16 +307,12 @@ package object types { */ sealed case class PowSetT(domType: CellT) extends CellT with Serializable { require(domType.isInstanceOf[FinSetT]) // currently, we support only PowSetT(FinSetT(_)) - /** - * Produce a short signature that uniquely describes the type (up to unification), - * similar to Java's signature mangling. If one type can be unified to another, - * e.g., records, they have the same signature. - * - * @return a short signature that uniquely characterizes this type up to unification - */ + override val signature: String = s"P${domType.signature}" override val toString: String = s"PowSet[$domType]" + + override def toTlaType1: TlaType1 = SetT1(domType.toTlaType1) } /** @@ -293,24 +325,22 @@ package object types { * @param resultType result type (not the co-domain!) */ sealed case class FunT(domType: CellT, resultType: CellT) extends CellT with Serializable { - - /** - * Produce a short signature that uniquely describes the type (up to unification), - * similar to Java's signature mangling. If one type can be unified to another, - * e.g., records, they have the same signature. - * - * @return a short signature that uniquely characterizes this type up to unification - */ override val signature: String = s"f${domType.signature}_${resultType.signature}" val argType: CellT = domType match { - case FinSetT(et) => et - case PowSetT(dt) => dt - case CrossProdT(args) => TupleT(args map { _.elemType }) - case _ => throw new TypeException(s"Unexpected domain type $domType", NullEx) + case FinSetT(et) => et + case PowSetT(dt) => dt + case _ => throw new TypeException(s"Unexpected domain type $domType", NullEx) } override val toString: String = s"Fun[$domType, $resultType]" + + override def toTlaType1: TlaType1 = { + domType.toTlaType1 match { + case SetT1(elemType) => FunT1(elemType, resultType.toTlaType1) + case tt => throw new TypingException("Unexpected function domain type: " + tt) + } + } } /** @@ -325,13 +355,6 @@ package object types { && (cdmType.isInstanceOf[FinSetT] || cdmType.isInstanceOf[PowSetT] || cdmType.isInstanceOf[FinFunSetT] || cdmType.isInstanceOf[InfSetT])) - /** - * Produce a short signature that uniquely describes the type (up to unification), - * similar to Java's signature mangling. If one type can be unified to another, - * e.g., records, they have the same signature. - * - * @return a short signature that uniquely characterizes this type up to unification - */ override val signature: String = s"F%${domType.signature}_%${cdmType.signature}" def argType(): CellT = domType match { @@ -347,6 +370,13 @@ package object types { } override val toString: String = s"FinFunSet[$domType, $cdmType]" + + override def toTlaType1: TlaType1 = { + (domType.toTlaType1, cdmType.toTlaType1) match { + case (SetT1(arg), SetT1(res)) => SetT1(FunT1(arg, res)) + case (dt, cdt) => throw new TypingException(s"Unexpected domain type $dt and result type $cdt") + } + } } /** @@ -356,15 +386,13 @@ package object types { */ sealed case class TupleT(args: Seq[CellT]) extends CellT with Serializable { - /** - * Produce a short signature that uniquely describes the type (up to unification), - * similar to Java's signature mangling. - * - * @return a short signature that uniquely characterizes this type up to unification - */ override val signature: String = s"T_${args.map(_.signature).mkString("_")}" override val toString: String = s"Tuple[${args.map(_.toString).mkString(", ")}]" + + override def toTlaType1: TlaType1 = { + TupT1(args.map(_.toTlaType1): _*) + } } /** @@ -376,36 +404,11 @@ package object types { */ sealed case class SeqT(res: CellT) extends CellT with Serializable { - /** - * Produce a short signature that uniquely describes the type (up to unification), - * similar to Java's signature mangling. - * - * @return a short signature that uniquely characterizes this type up to unification - */ override val signature: String = s"Q_${res.signature}" override val toString: String = s"Seq[$res]" - } - /** - * The type for a cross product, e.g., FinSetT(A) |X FinSetT(B). - * - * FIXME: this type should disappear in the future, - * as CrossProdT(FinSetT(A), FinSetT(B)) = FinSetT(TupleT(A, B)) - * - * @param args - */ - @deprecated("Never been used") - sealed case class CrossProdT(args: Seq[FinSetT]) extends CellT with Serializable { - - /** - * Produce a short signature that uniquely describes the type (up to unification), - * similar to Java's signature mangling. If one type can be unified to another, - * e.g., records, they have the same signature. - * - * @return a short signature that uniquely characterizes this type up to unification - */ - override val signature: String = args map { _.signature } mkString UTFPrinter.m_times + override def toTlaType1: TlaType1 = SeqT1(res.toTlaType1) } /** @@ -414,46 +417,14 @@ package object types { * @param fields a map of fields and their types */ sealed case class RecordT(fields: SortedMap[String, CellT]) extends CellT with Serializable { - - /** - * Produce a short signature that uniquely describes the type (up to unification), - * similar to Java's signature mangling. If one type can be unified to another, - * e.g., records, they have the same signature. - * - * As records having different domains can be unified, we do not include the records arguments in the signature. - * - * @return a short signature that uniquely characterizes this type up to unification - */ override val signature: String = "R" override val toString: String = s"Record[${fields.map { case (k, v) => "\"" + k + "\" -> " + v } mkString ", "}]" - } - // FIXME: Igor @ 20.12.2018: Do we still need this type? - sealed case class TypeParam(s: String) extends CellT { - - /** - * Produce a short signature that uniquely describes the type (up to unification), - * similar to Java's signature mangling. If one type can be unified to another, - * e.g., records, they have the same signature. - * - * @return a short signature that uniquely characterizes this type up to unification - */ - override val signature = s"P${s}" - } - - // FIXME: Igor @ 20.12.2018: Do we still need this type? - sealed case class OptT(elementType: CellT) extends CellT { - - /** - * Produce a short signature that uniquely describes the type (up to unification), - * similar to Java's signature mangling. If one type can be unified to another, - * e.g., records, they have the same signature. - * - * @return a short signature that uniquely characterizes this type up to unification - */ - override val signature: String = s"O${elementType.signature}" + override def toTlaType1: TlaType1 = { + RecT1(fields.mapValues(_.toTlaType1)) + } } /** diff --git a/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/RewriterBase.scala b/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/RewriterBase.scala index 29515f7d8b..d8c8096933 100644 --- a/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/RewriterBase.scala +++ b/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/RewriterBase.scala @@ -3,7 +3,6 @@ package at.forsyte.apalache.tla.bmcmt import java.io.{PrintWriter, StringWriter} import at.forsyte.apalache.tla.bmcmt.smt.{PreproSolverContext, SolverConfig, SolverContext, Z3SolverContext} -import at.forsyte.apalache.tla.bmcmt.types.eager.TrivialTypeFinder import at.forsyte.apalache.tla.lir.convenience.tla import at.forsyte.apalache.tla.lir.UntypedPredefs._ import org.scalatest.{BeforeAndAfterEach, FunSuite} @@ -26,7 +25,7 @@ class RewriterBase extends FunSuite with BeforeAndAfterEach { } protected def createWithoutCache(): SymbStateRewriter = { - new SymbStateRewriterImpl(solverContext, new TrivialTypeFinder()) + new SymbStateRewriterImpl(solverContext) } protected def assertUnsatOrExplain(rewriter: SymbStateRewriter, state: SymbState): Unit = { diff --git a/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestCherryPick.scala b/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestCherryPick.scala index 88d1e3e536..67bab5b1a9 100644 --- a/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestCherryPick.scala +++ b/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestCherryPick.scala @@ -1,38 +1,44 @@ package at.forsyte.apalache.tla.bmcmt -import at.forsyte.apalache.tla.bmcmt.rules.aux.{CherryPick, Oracle, OracleFactory, OracleHelper} +import at.forsyte.apalache.tla.bmcmt.rules.aux.{CherryPick, Oracle, OracleFactory} import at.forsyte.apalache.tla.bmcmt.types._ +import at.forsyte.apalache.tla.lir.TypedPredefs._ import at.forsyte.apalache.tla.lir._ -import at.forsyte.apalache.tla.lir.convenience.tla -import at.forsyte.apalache.tla.lir.UntypedPredefs._ +import at.forsyte.apalache.tla.lir.convenience.tla._ import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner @RunWith(classOf[JUnitRunner]) class TestCherryPick extends RewriterBase with TestingPredefs { - private def emptySetWithType(elemT: CellT): TlaEx = - tla.withType(tla.enumSet(), AnnotationParser.toTla(FinSetT(elemT))) + private val types = Map( + "b" -> BoolT1(), + "i" -> IntT1(), + "i_to_i" -> FunT1(IntT1(), IntT1()), + "I" -> SetT1(IntT1()), + "II" -> SetT1(SetT1(IntT1())), + "Qi" -> SeqT1(IntT1()), + "ii" -> TupT1(IntT1(), IntT1()), + "ri" -> RecT1("a" -> IntT1()), + "rii" -> RecT1("a" -> IntT1(), "b" -> IntT1()), + "riis" -> RecT1("a" -> IntT1(), "b" -> IntT1(), "c" -> StrT1()), + "Riis" -> SetT1(RecT1("a" -> IntT1(), "b" -> IntT1(), "c" -> StrT1())), + "i_ii" -> TupT1(IntT1(), TupT1(IntT1(), IntT1())) + ) private def assertEqWhenChosen(rewriter: SymbStateRewriter, state: SymbState, oracle: Oracle, position: Int, expected: TlaEx): SymbState = { rewriter.push() solverContext.assertGroundExpr(oracle.whenEqualTo(state, position)) - val ns = rewriter.rewriteUntilDone(state.setRex(tla.eql(state.ex, expected))) - rewriter.push() - solverContext.assertGroundExpr(ns.ex) - assert(solverContext.sat()) - rewriter.pop() - rewriter.push() - solverContext.assertGroundExpr(tla.not(ns.ex)) - assertUnsatOrExplain(rewriter, ns) - rewriter.pop() + val eq = eql(state.ex, expected).typed(BoolT1()) + assertTlaExAndRestore(rewriter, state.setRex(eq)) + rewriter.pop() state } - test("""CHERRY-PICK {1, 2, 2} ~~> $B$k""") { + test("""CHERRY-PICK {1, 2, 2}""") { val rewriter = create() - var state = new SymbState(tla.bool(true), arena, Binding()) + var state = new SymbState(bool(true).typed(BoolT1()), arena, Binding()) // introduce an oracle that tells us which element to pick val (oracleState, oracle) = new OracleFactory(rewriter).newConstOracle(state, 3) state = oracleState @@ -41,7 +47,7 @@ class TestCherryPick extends RewriterBase with TestingPredefs { // introduce integer cells directly arena = state.arena.appendCell(IntT()) val cell = arena.topCell - solverContext.assertGroundExpr(tla.eql(cell.toNameEx, tla.int(i))) + solverContext.assertGroundExpr(eql(cell.toNameEx ? "i", int(i)).typed(types, "b")) state = state.setArena(arena) cell } @@ -51,20 +57,20 @@ class TestCherryPick extends RewriterBase with TestingPredefs { .pickBasic(IntT(), state, oracle, intCells, state.arena.cellFalse().toNameEx) assert(solverContext.sat()) - assertEqWhenChosen(rewriter, pickedState, oracle, 0, tla.int(1)) - assertEqWhenChosen(rewriter, pickedState, oracle, 1, tla.int(2)) - assertEqWhenChosen(rewriter, pickedState, oracle, 2, tla.int(2)) + assertEqWhenChosen(rewriter, pickedState, oracle, 0, int(1).typed()) + assertEqWhenChosen(rewriter, pickedState, oracle, 1, int(2).typed()) + assertEqWhenChosen(rewriter, pickedState, oracle, 2, int(2).typed()) } - test("""CHERRY-PICK {<<1, 2>>, <<3, 4>>} ~~> $B$k""") { + test("""CHERRY-PICK {<<1, 2>>, <<3, 4>>}""") { val rewriter = create() - var state = new SymbState(tla.bool(true), arena, Binding()) + var state = new SymbState(bool(true).typed(), arena, Binding()) // introduce an oracle that tells us which element to pick val (oracleState, oracle) = new OracleFactory(rewriter).newConstOracle(state, 2) state = oracleState def mkTuple(i: Int, j: Int): ArenaCell = { - state = rewriter.rewriteUntilDone(state.setRex(tla.tuple(tla.int(i), tla.int(j)))) + state = rewriter.rewriteUntilDone(state.setRex(tuple(int(i), int(j)).typed(types, "ii"))) state.asCell } @@ -77,15 +83,15 @@ class TestCherryPick extends RewriterBase with TestingPredefs { assertEqWhenChosen(rewriter, state, oracle, 1, tuples(1).toNameEx) } - test("""CHERRY-PICK {<<1, <<2, 3>> >>, <<3, <<4, 5>> >>} ~~> $B$k""") { + test("""CHERRY-PICK {<<1, <<2, 3>> >>, <<3, <<4, 5>> >>}""") { val rewriter = create() - var state = new SymbState(tla.bool(true), arena, Binding()) + var state = new SymbState(bool(true).typed(), arena, Binding()) // introduce an oracle that tells us which element to pick val (oracleState, oracle) = new OracleFactory(rewriter).newConstOracle(state, 2) state = oracleState def mkTuple(i: Int, j: Int, k: Int): ArenaCell = { - state = rewriter.rewriteUntilDone(state.setRex(tla.tuple(tla.int(i), tla.tuple(tla.int(j), tla.int(k))))) + state = rewriter.rewriteUntilDone(state.setRex(tuple(int(i), tuple(int(j), int(k)) ? "ii").typed(types, "i_ii"))) state.asCell } @@ -100,15 +106,15 @@ class TestCherryPick extends RewriterBase with TestingPredefs { test("""CHERRY-PICK-SEQ {<<1, 2>>, <<3, 4>>}""") { val rewriter = create() - var state = new SymbState(tla.bool(true), arena, Binding()) + var state = new SymbState(bool(true).typed(BoolT1()), arena, Binding()) // introduce an oracle that tells us which element to pick val (oracleState, oracle) = new OracleFactory(rewriter).newConstOracle(state, 2) state = oracleState def mkSeq(args: Int*): ArenaCell = { - val tuple = tla.tuple(args map tla.int: _*) - val annot = tla.withType(tuple, AnnotationParser.toTla(SeqT(IntT()))) - state = rewriter.rewriteUntilDone(state.setRex(annot)) + val tup = tuple(args map int: _*) + .typed(types, "Qi") + state = rewriter.rewriteUntilDone(state.setRex(tup)) state.asCell } @@ -120,15 +126,17 @@ class TestCherryPick extends RewriterBase with TestingPredefs { assertEqWhenChosen(rewriter, state, oracle, 1, seqs(1).toNameEx) } - test("""CHERRY-PICK {[a |-> 1, b |-> 2], [a |-> 3, b |-> 4]} ~~> $B$k""") { + test("""CHERRY-PICK {[a |-> 1, b |-> 2], [a |-> 3, b |-> 4]}""") { val rewriter = create() - var state = new SymbState(tla.bool(true), arena, Binding()) + var state = new SymbState(bool(true).typed(), arena, Binding()) // introduce an oracle that tells us which element to pick val (oracleState, oracle) = new OracleFactory(rewriter).newConstOracle(state, 2) state = oracleState def mkRecord(i: Int, j: Int): ArenaCell = { - state = rewriter.rewriteUntilDone(state.setRex(tla.enumFun(tla.str("a"), tla.int(i), tla.str("b"), tla.int(j)))) + val rec = enumFun(str("a"), int(i), str("b"), int(j)) + .typed(types, "rii") + state = rewriter.rewriteUntilDone(state.setRex(rec)) state.asCell } @@ -141,15 +149,75 @@ class TestCherryPick extends RewriterBase with TestingPredefs { assertEqWhenChosen(rewriter, state, oracle, 1, records(1).toNameEx) } - test("""CHERRY-PICK { {1, 2}, {3, 4} } ~~> $B$k""") { + test("""CHERRY-PICK [a |-> 1, b |-> 2] or [a |-> 3]""") { + // After switching to Snowcat, we allow sets to mix records of compatible types. + // The old encoding was always introducing spurious fields for all records, as it was extending the records. + val rec1 = enumFun(str("a"), int(1), str("b"), int(2)) + .typed(types, "rii") + val rec2 = enumFun(str("a"), int(1)) + .typed(types, "ri") + + // introduce an oracle that tells us which element to pick + val rewriter = create() + var state = new SymbState(bool(true).typed(), arena, Binding()) + val (oracleState, oracle) = new OracleFactory(rewriter).newConstOracle(state, 2) + state = oracleState + state = rewriter.rewriteUntilDone(state.setRex(rec1)) + val rec1Cell = state.asCell + state = rewriter.rewriteUntilDone(state.setRex(rec2)) + val rec2Cell = state.asCell + + val recordCellType = CellT.fromType1(types("riis")) + state = new CherryPick(rewriter).pickRecord(recordCellType, state, oracle, Seq(rec1Cell, rec2Cell), + state.arena.cellFalse().toNameEx) + assert(solverContext.sat()) + + assertEqWhenChosen(rewriter, state, oracle, 0, rec1Cell.toNameEx) + assertEqWhenChosen(rewriter, state, oracle, 1, rec2Cell.toNameEx) + } + + test("""CHERRY-PICK {[a |-> 1, b |-> 2], [a |-> 3]}""") { + // After switching to Snowcat, we allow sets to mix records of compatible types. + // The old encoding was always introducing spurious fields for all records, as it was extending the records. + val rec1 = enumFun(str("a"), int(1), str("b"), int(2)) + .typed(types, "ri") + val rec2 = enumFun(str("a"), int(1)) + .typed(types, "rii") + + // introduce an oracle that tells us which element to pick + val rewriter = create() + var state = new SymbState(bool(true).typed(), arena, Binding()) + state = rewriter.rewriteUntilDone(state.setRex(rec1)) + val rec1Cell = state.asCell + state = rewriter.rewriteUntilDone(state.setRex(rec2)) + val rec2Cell = state.asCell + val set = enumSet(rec1Cell.toNameEx ? "riis", rec2Cell.toNameEx ? "riis") + .typed(types, "Riis") + state = rewriter.rewriteUntilDone(state.setRex(set)) + val setCell = state.asCell + + state = new CherryPick(rewriter).pick(setCell, state, bool(false).typed()) + assert(solverContext.sat()) + val result = state.asCell + // check that the result is equal to one of the records and nothing else + val eq1 = eql(result.toNameEx ? "riis", rec1Cell.toNameEx ? "riis") ? "b" + val eq2 = eql(result.toNameEx ? "riis", rec2Cell.toNameEx ? "riis") ? "b" + val eq1or2 = or(eq1, eq2) + .typed(types, "b") + assertTlaExAndRestore(rewriter, state.setRex(eq1or2)) + } + + test("""CHERRY-PICK { {1, 2}, {3, 4} }""") { val rewriter = create() - var state = new SymbState(tla.bool(true), arena, Binding()) + var state = new SymbState(bool(true).typed(), arena, Binding()) // introduce an oracle that tells us which element to pick val (oracleState, oracle) = new OracleFactory(rewriter).newConstOracle(state, 2) state = oracleState def mkSet(i: Int, j: Int): ArenaCell = { - state = rewriter.rewriteUntilDone(state.setRex(tla.enumSet(tla.int(i), tla.int(j)))) + val set = enumSet(int(i), int(j)) + .typed(types, "I") + state = rewriter.rewriteUntilDone(state.setRex(set)) state.asCell } @@ -163,17 +231,17 @@ class TestCherryPick extends RewriterBase with TestingPredefs { test("""CHERRY-PICK { {1, 2}, {} }""") { val rewriter = create() - var state = new SymbState(tla.bool(true), arena, Binding()) + var state = new SymbState(bool(true).typed(), arena, Binding()) // introduce an oracle that tells us which element to pick val (oracleState, oracle) = new OracleFactory(rewriter).newConstOracle(state, 2) state = oracleState def mkSet(setEx: TlaEx): ArenaCell = { - state = rewriter.rewriteUntilDone(state.setRex(tla.withType(setEx, AnnotationParser.toTla(FinSetT(IntT()))))) + state = rewriter.rewriteUntilDone(state.setRex(setEx)) state.asCell } - val sets = Seq(mkSet(tla.enumSet(tla.int(1), tla.int(2))), mkSet(tla.enumSet())) + val sets = Seq(mkSet(enumSet(int(1), int(2)).typed(types, "I")), mkSet(enumSet().typed(types, "I"))) state = new CherryPick(rewriter).pickSet(FinSetT(IntT()), state, oracle, sets, state.arena.cellFalse().toNameEx) assert(solverContext.sat()) @@ -181,9 +249,9 @@ class TestCherryPick extends RewriterBase with TestingPredefs { assertEqWhenChosen(rewriter, state, oracle, 1, sets(1).toNameEx) } - test("""CHERRY-PICK { {} } ~~> $B$k""") { + test("""CHERRY-PICK { {} }""") { val rewriter = create() - var state = new SymbState(tla.bool(true), arena, Binding()) + var state = new SymbState(bool(true).typed(), arena, Binding()) // introduce an oracle that tells us which element to pick val (oracleState, oracle) = new OracleFactory(rewriter).newConstOracle(state, 2) state = oracleState @@ -193,7 +261,7 @@ class TestCherryPick extends RewriterBase with TestingPredefs { state.asCell } - val sets = Seq(mkSet(tla.enumSet())) + val sets = Seq(mkSet(enumSet().typed(types, "I"))) state = new CherryPick(rewriter).pickSet(FinSetT(IntT()), state, oracle, sets, state.arena.cellFalse().toNameEx) assert(solverContext.sat()) @@ -202,7 +270,7 @@ class TestCherryPick extends RewriterBase with TestingPredefs { test("""CHERRY-PICK { {{1, 2}, {3, 4}}, {{5, 6}} }""") { val rewriter = create() - var state = new SymbState(tla.bool(true), arena, Binding()) + var state = new SymbState(bool(true).typed(), arena, Binding()) // introduce an oracle that tells us which element to pick val (oracleState, oracle) = new OracleFactory(rewriter).newConstOracle(state, 2) state = oracleState @@ -212,11 +280,11 @@ class TestCherryPick extends RewriterBase with TestingPredefs { state.asCell } - val set12 = tla.enumSet(tla.int(1), tla.int(2)) - val set34 = tla.enumSet(tla.int(3), tla.int(4)) - val set56 = tla.enumSet(tla.int(5), tla.int(6)) + val set12 = enumSet(int(1), int(2)) ? "I" + val set34 = enumSet(int(3), int(4)) ? "I" + val set56 = enumSet(int(5), int(6)) ? "I" val sets = - Seq(rewriteEx(tla.enumSet(set12, set34)), rewriteEx(tla.enumSet(set56))) + Seq(rewriteEx(enumSet(set12, set34).typed(types, "II")), rewriteEx(enumSet(set56).typed(types, "II"))) state = new CherryPick(rewriter).pickSet(FinSetT(FinSetT(IntT())), state, oracle, sets, state.arena.cellFalse().toNameEx) assert(solverContext.sat()) @@ -225,22 +293,24 @@ class TestCherryPick extends RewriterBase with TestingPredefs { assertEqWhenChosen(rewriter, state, oracle, 1, sets(1).toNameEx) } - test("""CHERRY-PICK { [x \in {1, 2} |-> 2 + x], [x \in {2, 3} |-> 2 * x] } ~~> $B$k""") { + test("""CHERRY-PICK { [x \in {1, 2} |-> 2 + x], [x \in {2, 3} |-> 2 * x] }""") { val rewriter = create() - var state = new SymbState(tla.bool(true), arena, Binding()) + var state = new SymbState(bool(true).typed(), arena, Binding()) // introduce an oracle that tells us which element to pick val (oracleState, oracle) = new OracleFactory(rewriter).newConstOracle(state, 2) state = oracleState - def mkFun(dom: TlaEx, map: TlaEx): ArenaCell = { - state = rewriter.rewriteUntilDone(state.setRex(tla.funDef(map, NameEx("x"), dom))) + def mkFun(dom: BuilderEx, map: BuilderEx): ArenaCell = { + val fun = funDef(map, name("x") ? "i", dom) + .typed(types, "i_to_i") + state = rewriter.rewriteUntilDone(state.setRex(fun)) state.asCell } - val set12 = tla.enumSet(tla.int(1), tla.int(2)) - val set23 = tla.enumSet(tla.int(2), tla.int(3)) - val fun1 = mkFun(set12, tla.plus(tla.int(2), tla.name("x"))) - val fun2 = mkFun(set23, tla.mult(tla.int(2), tla.name("x"))) + val set12 = enumSet(int(1), int(2)) ? "I" + val set23 = enumSet(int(2), int(3)) ? "I" + val fun1 = mkFun(set12, plus(int(2), name("x") ? "i") ? "i") + val fun2 = mkFun(set23, mult(int(2), name("x") ? "i") ? "i") val funs = Seq(fun1, fun2) val funT = FunT(FinSetT(IntT()), IntT()) state = new CherryPick(rewriter).pickFun(funT, state, oracle, funs, state.arena.cellFalse().toNameEx) diff --git a/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestModelChecker.scala b/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestModelChecker.scala deleted file mode 100644 index b1c10da779..0000000000 --- a/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestModelChecker.scala +++ /dev/null @@ -1,814 +0,0 @@ -package at.forsyte.apalache.tla.bmcmt - -import at.forsyte.apalache.tla.bmcmt.analyses._ -import at.forsyte.apalache.tla.bmcmt.search.BfsStrategy -import at.forsyte.apalache.tla.bmcmt.types.eager.TrivialTypeFinder -import at.forsyte.apalache.tla.imp.src.SourceStore -import at.forsyte.apalache.tla.lir._ -import at.forsyte.apalache.tla.lir.convenience.tla -import at.forsyte.apalache.tla.lir.oper.BmcOper -import at.forsyte.apalache.tla.lir.storage.ChangeListener -import at.forsyte.apalache.tla.lir.UntypedPredefs._ -import org.junit.runner.RunWith -import org.scalatest.junit.JUnitRunner -import org.scalatest.{BeforeAndAfter, FunSuite} - -@RunWith(classOf[JUnitRunner]) -class TestModelChecker extends FunSuite with BeforeAndAfter { - private var typeFinder: TrivialTypeFinder = new TrivialTypeFinder() - private var exprGradeStore: ExprGradeStore = new ExprGradeStoreImpl() - private var hintsStore: FormulaHintsStoreImpl = new FormulaHintsStoreImpl() - private var changeListener: ChangeListener = new ChangeListener() - private var sourceStore: SourceStore = _ - - before { - // initialize the model checker - typeFinder = new TrivialTypeFinder() - exprGradeStore = new ExprGradeStoreImpl() - sourceStore = new SourceStore() - } - - after { - typeFinder.reset(Map()) - } - - test("Init, OK") { - // x' \in {2} - val initTrans = List(mkAssign("x", 2)) - val nextTrans = List(mkAssign("x", 2)) - val dummyModule = new TlaModule("root", List()) - val checkerInput = - new CheckerInput(dummyModule, initTrans, nextTrans, None, List.empty) - val strategy = new BfsStrategy(checkerInput, stepsBound = 0) - val checker = new ModelChecker( - typeFinder, - hintsStore, - changeListener, - exprGradeStore, - sourceStore, - checkerInput, - strategy, - Map(), - debug = false, - profile = false - ) - val outcome = checker.run() - assert(Checker.Outcome.NoError == outcome) - } - - test("Init, deadlock") { - // x' \in {2} /\ x' \in {1} - val initTrans = List(tla.and(mkAssign("x", 2), mkNotAssign("x", 1)).untyped()) - val nextTrans = List(mkAssign("x", 2)) - val dummyModule = new TlaModule("root", List()) - val checkerInput = - new CheckerInput(dummyModule, initTrans, nextTrans, None, List.empty) - // initialize the model checker - val strategy = new BfsStrategy(checkerInput, stepsBound = 0) - val checker = new ModelChecker( - typeFinder, - hintsStore, - changeListener, - exprGradeStore, - sourceStore, - checkerInput, - strategy, - Map(), - debug = false, - profile = false - ) - val outcome = checker.run() - assert(Checker.Outcome.Deadlock == outcome) - } - - test("Init, 2 options, OK") { - // x' \in {2} \/ x' \in {1} - val initTrans = List(mkAssign("x", 2), mkAssign("x", 1)) - val nextTrans = List(mkAssign("x", 2)) - val dummyModule = new TlaModule("root", List()) - val checkerInput = - new CheckerInput(dummyModule, initTrans, nextTrans, None, List.empty) - // initialize the model checker - val strategy = new BfsStrategy(checkerInput, stepsBound = 0) - val checker = new ModelChecker( - typeFinder, - hintsStore, - changeListener, - exprGradeStore, - sourceStore, - checkerInput, - strategy, - Map(), - debug = false, - profile = false - ) - val outcome = checker.run() - assert(Checker.Outcome.NoError == outcome) - } - - test("Init + Next, 1 step, OK") { - // x' \in {2} \/ x' \in {1} - val initTrans = List(mkAssign("x", 2), mkAssign("x", 1)) - // x' \in {x + 1} - val nextTrans = List(mkAssign("x", tla.plus(tla.name("x"), tla.int(1)))) - val dummyModule = new TlaModule("root", List()) - val checkerInput = - new CheckerInput(dummyModule, initTrans, nextTrans, None, List.empty) - // initialize the model checker - val strategy = new BfsStrategy(checkerInput, stepsBound = 1) - val checker = new ModelChecker( - typeFinder, - hintsStore, - changeListener, - exprGradeStore, - sourceStore, - checkerInput, - strategy, - Map(), - debug = false, - profile = false - ) - val outcome = checker.run() - assert(Checker.Outcome.NoError == outcome) - } - - test("Init + Next, 1 step, formula hint") { - // x' \in {1} - val initTrans = List(mkAssign("x", 1)) - // x < 10 /\ x' \in {x + 1} - val nextTrans = - List( - tla - .and( - tla.lt(tla.name("x"), tla.int(10)), - mkAssign("x", tla.plus(tla.name("x"), tla.int(1))) - ) - .untyped() - ) - /// - val dummyModule = new TlaModule("root", List()) - val checkerInput = - new CheckerInput(dummyModule, initTrans, nextTrans, None, List.empty) - - // Add the hint. We cannot check in the test, whether the hints was actually used. - // We only check that the checker works in presence of hints. - hintsStore.store.put(nextTrans.head.ID, FormulaHintsStore.HighAnd()) - // initialize the model checker - val strategy = new BfsStrategy(checkerInput, stepsBound = 1) - val checker = new ModelChecker( - typeFinder, - hintsStore, - changeListener, - exprGradeStore, - sourceStore, - checkerInput, - strategy, - Map(), - debug = false, - profile = false - ) - val outcome = checker.run() - assert(Checker.Outcome.NoError == outcome) - } - - test("determinstic Init + 2 steps (regression)") { - // y' \in {1} /\ x' \in {1} - val initTrans = List(tla.and(mkAssign("y", 1), mkAssign("x", 1)).untyped()) - // y' \in {y + 1} /\ x' \in {x + 1} - val nextTrans = List( - tla - .and( - mkAssign("y", tla.plus(tla.name("y"), tla.int(1))), - mkAssign("x", tla.plus(tla.name("x"), tla.int(1))) - ) - .untyped() - ) - /// - val dummyModule = new TlaModule("root", List()) - val inv = tla.eql( - tla.eql(tla.int(3), tla.name("x")), - tla.eql(tla.int(3), tla.name("y")) - ) //// - - val checkerInput = new CheckerInput( - dummyModule, - initTrans, - nextTrans, - None, - List((inv, tla.not(inv))) - ) - // initialize the model checker - val strategy = new BfsStrategy(checkerInput, stepsBound = 2) - val checker = new ModelChecker( - typeFinder, - hintsStore, - changeListener, - exprGradeStore, - sourceStore, - checkerInput, - strategy, - Map(), - debug = false, - profile = false - ) - val outcome = checker.run() - assert(Checker.Outcome.NoError == outcome) - } - - test("Init + 2 steps with LET-IN") { - // x' \in {1} - val initTrans = List(mkAssign("x", 1)) - // LET A == 1 + x IN x' \in { A + 1 } - val aDecl = TlaOperDecl("A", List(), tla.plus(tla.int(1), tla.name("x"))) - - val letIn = tla.letIn(tla.plus(tla.appDecl(aDecl), tla.int(1)), aDecl) - - val nextTrans = List(mkAssign("x", letIn)) - /// - val dummyModule = new TlaModule("root", List()) - val inv = tla.not(tla.eql(tla.int(4), tla.name("x"))) - - val checkerInput = new CheckerInput( - dummyModule, - initTrans, - nextTrans, - None, - List((inv, tla.not(inv))) - ) - // initialize the model checker - val strategy = new BfsStrategy(checkerInput, stepsBound = 2) - val checker = new ModelChecker( - typeFinder, - hintsStore, - changeListener, - exprGradeStore, - sourceStore, - checkerInput, - strategy, - Map(), - debug = false, - profile = false - ) - val outcome = checker.run() - assert(Checker.Outcome.NoError == outcome) - } - - test("Init + Next, 1 step, deadlock") { - // x' \in {2} \/ x' \in {1} - val initTrans = List(mkAssign("x", 2), mkAssign("x", 1)) - // x > 3 /\ x' \in {x + 1} - val nextTrans = List( - tla - .and( - tla.gt(tla.name("x"), tla.int(3)), - mkAssign("x", tla.plus(tla.name("x"), tla.int(1))) - ) - .untyped() - ) - val dummyModule = new TlaModule("root", List()) - val checkerInput = - new CheckerInput(dummyModule, initTrans, nextTrans, None, List()) - // initialize the model checker - val strategy = new BfsStrategy(checkerInput, stepsBound = 1) - val checker = new ModelChecker( - typeFinder, - hintsStore, - changeListener, - exprGradeStore, - sourceStore, - checkerInput, - strategy, - Map(), - debug = false, - profile = false - ) - val outcome = checker.run() - assert(Checker.Outcome.Deadlock == outcome) - } - - test("Init + Next, 10 steps, OK") { - // x' \in {2} \/ x' \in {1} - val initTrans = List(mkAssign("x", 2), mkAssign("x", 1)) - // x' \in {x + 1} - val nextTrans = List(mkAssign("x", tla.plus(tla.name("x"), tla.int(1)))) - val dummyModule = new TlaModule("root", List()) - val checkerInput = - new CheckerInput(dummyModule, initTrans, nextTrans, None, List()) - // initialize the model checker - val strategy = new BfsStrategy(checkerInput, stepsBound = 10) - val checker = new ModelChecker( - typeFinder, - hintsStore, - changeListener, - exprGradeStore, - sourceStore, - checkerInput, - strategy, - Map(), - debug = false, - profile = false - ) - val outcome = checker.run() - assert(Checker.Outcome.NoError == outcome) - } - - test("Init + Next, 10 steps, deadlock") { - // x' \in {2} \/ x' \in {1} - val initTrans = List(mkAssign("x", 2), mkAssign("x", 1)) - // x < 10 /\ x' \in {x + 1} - val nextTrans = List( - tla - .and( - tla.lt(tla.name("x"), tla.int(10)), - mkAssign("x", tla.plus(tla.name("x"), tla.int(1))) - ) - .untyped() - ) - val dummyModule = new TlaModule("root", List()) - val checkerInput = - new CheckerInput(dummyModule, initTrans, nextTrans, None, List()) - // initialize the model checker - val strategy = new BfsStrategy(checkerInput, stepsBound = 10) - val checker = new ModelChecker( - typeFinder, - hintsStore, - changeListener, - exprGradeStore, - sourceStore, - checkerInput, - strategy, - Map(), - debug = false, - profile = false - ) - val outcome = checker.run() - assert(Checker.Outcome.Deadlock == outcome) - } - - test("Init + Next + Inv, 10 steps, OK") { - // x' \in {2} \/ x' \in {1} - val initTrans = List(mkAssign("x", 2), mkAssign("x", 1)) - // x' \in {x + 1} - val nextTrans = List(mkAssign("x", tla.plus(tla.name("x"), tla.int(1)))) - // x < 100 - val inv = tla.lt(tla.name("x"), tla.int(100)) - val dummyModule = new TlaModule("root", List()) - val checkerInput = new CheckerInput( - dummyModule, - initTrans, - nextTrans, - None, - List((inv, tla.not(inv))) - ) - // initialize the model checker - val strategy = new BfsStrategy(checkerInput, stepsBound = 10) - val checker = new ModelChecker( - typeFinder, - hintsStore, - changeListener, - exprGradeStore, - sourceStore, - checkerInput, - strategy, - Map(), - debug = false, - profile = false - ) - val outcome = checker.run() - assert(Checker.Outcome.NoError == outcome) - } - - test("Init + Next + Inv, 10 steps, learnFromUnsat, OK") { - // x' \in {2} \/ x' \in {1} - val initTrans = List(mkAssign("x", 2), mkAssign("x", 1)) - // x' \in {x + 1} - val nextTrans = List(mkAssign("x", tla.plus(tla.name("x"), tla.int(1)))) - // x < 100 - val inv = tla.lt(tla.name("x"), tla.int(100)) - val dummyModule = new TlaModule("root", List()) - val checkerInput = new CheckerInput( - dummyModule, - initTrans, - nextTrans, - None, - List((inv, tla.not(inv))) - ) - // initialize the model checker - val strategy = new BfsStrategy(checkerInput, stepsBound = 10) - val tuning = Map("search.invariant.learnFromUnsat" -> "true") - val checker = new ModelChecker( - typeFinder, - hintsStore, - changeListener, - exprGradeStore, - sourceStore, - checkerInput, - strategy, - tuning, - debug = false, - profile = false - ) - val outcome = checker.run() - assert(Checker.Outcome.NoError == outcome) - } - - test("Init + Next + Inv, 10 steps, !search.invariant.split, OK") { - // x' \in {2} \/ x' \in {1} - val initTrans = List(mkAssign("x", 2), mkAssign("x", 1)) - // x' \in {x + 1} - val nextTrans = List(mkAssign("x", tla.plus(tla.name("x"), tla.int(1)))) - // x < 100 - val inv = tla.lt(tla.name("x"), tla.int(100)) - val dummyModule = new TlaModule("root", List()) - val checkerInput = new CheckerInput( - dummyModule, - initTrans, - nextTrans, - None, - List((inv, tla.not(inv))) - ) - // initialize the model checker - val strategy = new BfsStrategy(checkerInput, stepsBound = 10) - val tuning = Map("search.invariant.split" -> "false") - val checker = new ModelChecker( - typeFinder, - hintsStore, - changeListener, - exprGradeStore, - sourceStore, - checkerInput, - strategy, - tuning, - debug = false, - profile = false - ) - val outcome = checker.run() - assert(Checker.Outcome.NoError == outcome) - } - - test("Init + Next + Inv, 10 steps, error") { - // x' \in {2} \/ x' \in {1} - val initTrans = List(mkAssign("x", 2), mkAssign("x", 1)) - // x' \in {x + 1} - val nextTrans = List(mkAssign("x", tla.plus(tla.name("x"), tla.int(1)))) - // x < 5 - val inv = tla.lt(tla.name("x"), tla.int(5)) - val dummyModule = new TlaModule("root", List()) - val checkerInput = new CheckerInput( - dummyModule, - initTrans, - nextTrans, - None, - List((inv, tla.not(inv))) - ) - // initialize the model checker - val strategy = new BfsStrategy(checkerInput, stepsBound = 10) - val checker = new ModelChecker( - typeFinder, - hintsStore, - changeListener, - exprGradeStore, - sourceStore, - checkerInput, - strategy, - Map(), - debug = false, - profile = false - ) - val outcome = checker.run() - assert(Checker.Outcome.Error == outcome) - } - - test("Init + Next + Inv, 3 steps, error, edge case") { - // the invariant is violated in the last state of a bounded execution - - // x' \in {0} - val initTrans = List(mkAssign("x", 0)) - // x' \in {x + 1} - val nextTrans = List(mkAssign("x", tla.plus(tla.name("x"), tla.int(1)))) - // x /= 3 - val inv = tla.not(tla.eql(tla.name("x"), tla.int(3))) - val dummyModule = new TlaModule("root", List()) - val checkerInput = new CheckerInput( - dummyModule, - initTrans, - nextTrans, - None, - List((inv, tla.not(inv))) - ) - // initialize the model checker - val strategy = new BfsStrategy(checkerInput, stepsBound = 3) - val checker = new ModelChecker( - typeFinder, - hintsStore, - changeListener, - exprGradeStore, - sourceStore, - checkerInput, - strategy, - Map(), - debug = false, - profile = false - ) - val outcome = checker.run() - assert(Checker.Outcome.Error == outcome) - } - - test("Init + Next + Inv, 2 steps, no error, edge case") { - // the invariant is violated in the last state of a bounded execution - - // x' \in {0} - val initTrans = List(mkAssign("x", 0)) - // x' \in {x + 1} - val nextTrans = List(mkAssign("x", tla.plus(tla.name("x"), tla.int(1)))) - // x /= 3 - val inv = tla.not(tla.eql(tla.name("x"), tla.int(3))) - val dummyModule = new TlaModule("root", List()) - val checkerInput = new CheckerInput( - dummyModule, - initTrans, - nextTrans, - None, - List((inv, tla.not(inv))) - ) - // initialize the model checker - val strategy = new BfsStrategy(checkerInput, stepsBound = 2) - val checker = new ModelChecker( - typeFinder, - hintsStore, - changeListener, - exprGradeStore, - sourceStore, - checkerInput, - strategy, - Map(), - debug = false, - profile = false - ) - val outcome = checker.run() - assert(Checker.Outcome.NoError == outcome) - } - - test("Init + Next + Inv, 10 steps, and invariantFilter") { - // x' \in {2} \/ x' \in {1} - val initTrans = List(mkAssign("x", 2), mkAssign("x", 1)) - // x' \in {x + 1} - val nextTrans = List(mkAssign("x", tla.plus(tla.name("x"), tla.int(1)))) - // x < 5 - val inv = tla.lt(tla.name("x"), tla.int(5)) - val dummyModule = new TlaModule("root", List()) - val checkerInput = new CheckerInput( - dummyModule, - initTrans, - nextTrans, - None, - List((inv, tla.not(inv))) - ) - // initialize the model checker - val strategy = new BfsStrategy(checkerInput, stepsBound = 10) - // We require the invariant to be checked only after the second step. So we will miss invariant violation. - val tuning = Map("search.invariantFilter" -> "2") - val checker = new ModelChecker( - typeFinder, - hintsStore, - changeListener, - exprGradeStore, - sourceStore, - checkerInput, - strategy, - tuning, - debug = false, - profile = false - ) - val outcome = checker.run() - assert(Checker.Outcome.NoError == outcome) - } - - test("Init + Next, 3 steps, non-determinism but no deadlock") { - // x' \in {1} - val initTrans = List(mkAssign("x", 1)) - // x' \in {x + 1} \/ x > 100 /\ x' \in {x} - val trans1 = mkAssign("x", tla.plus(tla.name("x"), tla.int(1))) - val trans2 = - tla.and(tla.gt(tla.name("x"), tla.int(100)), mkAssign("x", tla.name("x"))).untyped() - val nextTrans = List(trans1, trans2) - val dummyModule = new TlaModule("root", List()) - val checkerInput = - new CheckerInput(dummyModule, initTrans, nextTrans, None, List()) - // initialize the model checker - val strategy = new BfsStrategy(checkerInput, stepsBound = 3) - val checker = new ModelChecker( - typeFinder, - hintsStore, - changeListener, - exprGradeStore, - sourceStore, - checkerInput, - strategy, - Map(), - debug = false, - profile = false - ) - val outcome = checker.run() - assert(Checker.Outcome.NoError == outcome) - } - - test("Init + Next + Inv, 2 steps, set assignments") { - // sets require an explicit equality, and that is where picking the next state may fail - // Init == x \in {2} /\ y \in {10} - // Next == \/ x' = x \cup {2} /\ y' = y \setminus {11} - // \/ x' = x \setminus {2} /\ y' = y \cup {11} - // Inv == 11 \in y <=> 2 \notin x - - // Init == x' = {2} /\ y = {10} - val init = tla.and( - mkAssign("x", tla.enumSet(tla.int(2))), - mkAssign("y", tla.enumSet(tla.int(10))) - ) - - // as KerA+ does not have setminus, we use a filter here - def setminus(setName: String, boundName: String, intVal: Int): TlaEx = { - tla.filter( - tla.name(boundName), - tla.name(setName), - tla.not(tla.eql(tla.name(boundName), tla.int(intVal))) - ) - } - - // Next == \/ x' = x \cup {2} /\ y' = y \setminus {11} - // \/ x' = x \setminus {2} /\ y' = y \cup {11} - val next1 = - tla.and( - mkAssign("x", tla.cup(tla.name("x"), tla.enumSet(tla.int(2)))), - mkAssign("y", setminus("y", "t1", 11)) - ) - /// - /// - val next2 = - tla.and( - mkAssign("x", setminus("x", "t2", 2)), - mkAssign("y", tla.cup(tla.name("y"), tla.enumSet(tla.int(11)))) - ) /// - - // Inv == 11 \in y <=> 2 \notin x - val inv = tla.eql( - tla.in(tla.int(11), tla.name("y")), - tla.not(tla.in(tla.int(2), tla.name("x"))) - ) //// - - val dummyModule = new TlaModule("root", List()) - val checkerInput = new CheckerInput( - dummyModule, - List(init), - List(next1, next2), - None, - List((inv, tla.not(inv))) - ) - // initialize the model checker - val strategy = new BfsStrategy(checkerInput, stepsBound = 2) - val checker = new ModelChecker( - typeFinder, - hintsStore, - changeListener, - exprGradeStore, - sourceStore, - checkerInput, - strategy, - Map(), - debug = false, - profile = false - ) - val outcome = checker.run() - assert(Checker.Outcome.NoError == outcome) - } - - test("Init + Next, 10 steps, non-determinism in init and next") { - // x' \in {0} \/ x' \in {1} - val initTrans = List(mkAssign("x", 0), mkAssign("x", 1)) - // x' \in {x + 1} \/ x > 10 /\ x' \in {x} - val trans1 = mkAssign("x", tla.plus(tla.name("x"), tla.int(1))) - val trans2 = - tla.and(tla.gt(tla.name("x"), tla.int(10)), mkAssign("x", tla.name("x"))).untyped() - val nextTrans = List(trans1, trans2) - val notInv = tla.gt(tla.prime(tla.name("x")), tla.int(10)) // ~(x <= 10) - val dummyModule = new TlaModule("root", List()) - val checkerInput = new CheckerInput( - dummyModule, - initTrans, - nextTrans, - None, - List((tla.not(notInv), notInv)) - ) - // initialize the model checker - val strategy = new BfsStrategy(checkerInput, stepsBound = 10) - val checker = new ModelChecker( - typeFinder, - hintsStore, - changeListener, - exprGradeStore, - sourceStore, - checkerInput, - strategy, - Map(), - debug = false, - profile = false - ) - val outcome = checker.run() - assert(Checker.Outcome.Error == outcome) - } - - test("cInit + Init + Next, 10 steps") { - // x' \in {0} \/ x' \in {1} - val initTrans = List(mkAssign("x", 0), mkAssign("x", 1)) - // x' \in {x + 1} \/ x > 10 /\ x' \in {x} - val trans1 = mkAssign("x", tla.plus(tla.name("x"), tla.int(1))) - val trans2 = - tla.and(tla.gt(tla.name("x"), tla.int(10)), mkAssign("x", tla.name("x"))).untyped() - val nextTrans = List(trans1, trans2) - // a constant initializer: \E t \in { 20, 10 }: N' \in {t} - val cInit = - OperEx( - BmcOper.skolem, - tla.exists( - tla.name("t"), - tla.enumSet(tla.int(20), tla.int(10)), - tla.in(tla.prime(tla.name("N")), tla.enumSet(tla.name("t"))) - ) - ) //// - - val notInv = tla.gt(tla.prime(tla.name("x")), tla.name("N")) // ~(x <= N) - val dummyModule = - new TlaModule("root", List(TlaConstDecl("N"), TlaVarDecl("x"))) - val checkerInput = new CheckerInput( - dummyModule, - initTrans, - nextTrans, - Some(cInit), - List((tla.not(notInv), notInv)) - ) - // initialize the model checker - val strategy = new BfsStrategy(checkerInput, stepsBound = 10) - val checker = new ModelChecker( - typeFinder, - hintsStore, - changeListener, - exprGradeStore, - sourceStore, - checkerInput, - strategy, - Map(), - debug = false, - profile = false - ) - val outcome = checker.run() - assert(Checker.Outcome.Error == outcome) - } - - test("Init + Next, 10 steps and filter") { - // x' \in {0} - val initTrans = List(mkAssign("x", 0)) - // x' \in {x + 1} \/ x' \in {x + 2} - val trans1 = mkAssign("x", tla.plus(tla.name("x"), tla.int(1))) - val trans2 = mkAssign("x", tla.plus(tla.name("x"), tla.int(2))) - val nextTrans = List(trans1, trans2) - val notInv = tla.gt(tla.name("x"), tla.int(11)) // ~(x <= 11) - val dummyModule = new TlaModule("root", List()) - val checkerInput = new CheckerInput( - dummyModule, - initTrans, - nextTrans, - None, - List((tla.not(notInv), notInv)) - ) - // initialize the model checker - val filter = "0,0,0,0,0,0,0,0,0,0,0" // execute initTrans once and onlytrans1 10 times - val tuning = Map("search.transitionFilter" -> filter) - val strategy = new BfsStrategy(checkerInput, stepsBound = 10) - val checker = new ModelChecker( - typeFinder, - hintsStore, - changeListener, - exprGradeStore, - sourceStore, - checkerInput, - strategy, - tuning, - debug = false, - profile = false - ) - val outcome = checker.run() - assert(Checker.Outcome.NoError == outcome) - } - - private def mkAssign(name: String, value: Int): TlaEx = - tla.assignPrime(tla.name(name), tla.int(value)) - - private def mkAssign(name: String, rhs: TlaEx): TlaEx = - tla.assignPrime(tla.name(name), rhs) - - private def mkNotAssign(name: String, value: Int): TlaEx = - tla.primeEq(tla.name(name), tla.int(value)) - - private def mkNotAssign(name: String, rhs: TlaEx): TlaEx = - tla.primeEq(tla.name(name), rhs) -} diff --git a/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestRewriterKeraSet.scala b/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestRewriterKeraSet.scala deleted file mode 100644 index 282c29b559..0000000000 --- a/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestRewriterKeraSet.scala +++ /dev/null @@ -1,84 +0,0 @@ -package at.forsyte.apalache.tla.bmcmt - -import at.forsyte.apalache.tla.lir._ -import at.forsyte.apalache.tla.lir.convenience._ -import at.forsyte.apalache.tla.lir.transformations.impl.TrackerWithListeners -import at.forsyte.apalache.tla.typecheck.TypedPredefs._ -import at.forsyte.apalache.tla.pp.{Keramelizer, UniqueNameGenerator} -import at.forsyte.apalache.tla.typecheck.{BoolT1, IntT1, SetT1} -import org.junit.runner.RunWith -import org.scalatest.junit.JUnitRunner - -/** - * Tests for the TLA+ operators that are deal with by rewriting into KerA+. - * Although they are not needed to test the rewriting rules, we keep them for regression. - * - * @author Igor Konnov - */ -@RunWith(classOf[JUnitRunner]) -class TestRewriterKeraSet extends RewriterBase with TestingPredefs { - private var keramelizer = new Keramelizer(new UniqueNameGenerator, TrackerWithListeners()) - - test("""SE-SET-CAP[1-2]: {1, 3} \cap {3, 4} = {3}""") { - val types = Map("S" -> SetT1(IntT1()), "i" -> IntT1()) - val left = tla.enumSet(tla.int(1), tla.int(3)) - val right = tla.enumSet(tla.int(3), tla.int(4)) - val expected = tla - .enumSet(tla.int(3)) - .typed(types, "S") - val intersection = tla - .cap(left ? "S", right ? "S") - .typed(types, "S") - val capSet = keramelizer.transform(intersection) - val eqExpected = tla.eql(capSet, expected).typed(BoolT1()) - - val state = new SymbState(eqExpected, arena, Binding()) - val rewriter = create() - assertTlaExAndRestore(rewriter, state) - } - - test("""SE-SET-DIFF[1-2]: {1, 3, 5} \cap {1, 4} = {3, 5}""") { - val types = Map("S" -> SetT1(IntT1()), "i" -> IntT1()) - val left = tla.enumSet(tla.int(1), tla.int(3), tla.int(5)) ? "S" - val right = tla.enumSet(tla.int(1), tla.int(4)) ? "S" - val expected = tla - .enumSet(tla.int(3), tla.int(5)) - .typed(types, "S") - val diff = tla - .setminus(left, right) - .typed(types, "S") - val minusSet = keramelizer.transform(diff) - val eqExpected = tla - .eql(minusSet, expected) - .typed(BoolT1()) - - val state = new SymbState(eqExpected, arena, Binding()) - val rewriter = create() - assertTlaExAndRestore(rewriter, state) - } - - test("""SE-SET-CUP: regression""") { - val types = Map("S" -> SetT1(IntT1()), "i" -> IntT1()) - // 2019-01-18, Igor: this bug originally appeared in TwoPhase.tla, the MWE can be found in Bug20190118.tla - // S = {1} \ {1} - val set1 = tla.setminus(tla.enumSet(tla.int(1)) ? "S", tla.enumSet(tla.int(1)) ? "S") ? "S" - // T = {3} \ 3 - val set2 = tla.setminus(tla.enumSet(tla.int(3)) ? "S", tla.enumSet(tla.int(3)) ? "S") ? "S" - // R = S \cup T = {} - // the buggy implementation will try in(1, T) and this may return true! - val set3 = tla - .cup(set1, set2) - .typed(types, "S") - val membership = tla - .in(tla.int(1), set3) - .typed(BoolT1()) - val trans = keramelizer.transform(membership) - val state = new SymbState(trans, arena, Binding()) - val rewriter = create() - val nextState = rewriter.rewriteUntilDone(state) - solverContext.assertGroundExpr(nextState.ex) - // the buggy implementation had: 1 \ in R - assertUnsatOrExplain(rewriter, nextState) - } - -} diff --git a/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestSeqModelChecker.scala b/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestSeqModelChecker.scala index 8f521875d4..2265fcc057 100644 --- a/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestSeqModelChecker.scala +++ b/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestSeqModelChecker.scala @@ -8,43 +8,48 @@ import at.forsyte.apalache.tla.bmcmt.smt.{RecordingSolverContext, SolverConfig} import at.forsyte.apalache.tla.bmcmt.trex.{ FilteredTransitionExecutor, IncrementalExecutionContext, TransitionExecutorImpl } -import at.forsyte.apalache.tla.bmcmt.types.eager.TrivialTypeFinder import at.forsyte.apalache.tla.lir._ -import at.forsyte.apalache.tla.lir.convenience.tla -import at.forsyte.apalache.tla.lir.oper.BmcOper -import at.forsyte.apalache.tla.lir.UntypedPredefs._ +import at.forsyte.apalache.tla.lir.convenience.tla._ +import at.forsyte.apalache.tla.lir.TypedPredefs._ import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner import org.scalatest.{BeforeAndAfter, FunSuite} @RunWith(classOf[JUnitRunner]) class TestSeqModelChecker extends FunSuite with BeforeAndAfter { - private var typeFinder: TrivialTypeFinder = new TrivialTypeFinder() private var solver: RecordingSolverContext = RecordingSolverContext.createZ3(None, SolverConfig(debug = false, profile = false, 0)) - private var rewriter = new SymbStateRewriterImpl(solver, typeFinder, new ExprGradeStoreImpl) + private var rewriter = new SymbStateRewriterImpl(solver, new ExprGradeStoreImpl) + private val types = Map( + "i" -> IntT1(), + "I" -> SetT1(IntT1()), + "b" -> BoolT1(), + "Ob" -> OperT1(Seq(), BoolT1()) + ) + private val intTag: Typed[TlaType1] = Typed(IntT1()) before { // initialize the model checker - typeFinder = new TrivialTypeFinder() solver = RecordingSolverContext.createZ3(None, SolverConfig(debug = false, profile = false, 0)) - rewriter = new SymbStateRewriterImpl(solver, typeFinder, new ExprGradeStoreImpl) + rewriter = new SymbStateRewriterImpl(solver, new ExprGradeStoreImpl) } private def mkModuleWithX(): TlaModule = { - new TlaModule("root", List(TlaVarDecl("x"))) + new TlaModule("root", List(TlaVarDecl("x")(Typed(IntT1())))) } private def mkModuleWithXandY(): TlaModule = { - new TlaModule("root", List(TlaVarDecl("x"), TlaVarDecl("y"))) + new TlaModule("root", List(TlaVarDecl("x")(intTag), TlaVarDecl("y")(intTag))) } test("Init + Inv => OK") { // x' <- 2 val initTrans = List(mkAssign("x", 2)) val nextTrans = List(mkAssign("x", 2)) - val notInv = tla.gt(tla.name("x"), tla.int(10)) - val checkerInput = new CheckerInput(mkModuleWithX(), initTrans, nextTrans, None, List((tla.not(notInv), notInv))) + val notInv = gt(name("x") ? "i", int(10)) + .typed(types, "b") + val inv = not(notInv).typed(types, "b") + val checkerInput = new CheckerInput(mkModuleWithX(), initTrans, nextTrans, None, List((inv, notInv))) val params = new ModelCheckerParams(checkerInput, stepsBound = 0, new File("."), Map(), false) val ctx = new IncrementalExecutionContext(rewriter) val trex = new TransitionExecutorImpl(params.consts, params.vars, ctx) @@ -57,8 +62,10 @@ class TestSeqModelChecker extends FunSuite with BeforeAndAfter { // x' <- 2 val initTrans = List(mkAssign("x", 2)) val nextTrans = List(mkAssign("x", 2)) - val notInv = tla.lt(tla.name("x"), tla.int(10)) - val checkerInput = new CheckerInput(mkModuleWithX(), initTrans, nextTrans, None, List((tla.not(notInv), notInv))) + val notInv = lt(name("x") ? "i", int(10)) + .typed(types, "b") + val inv = not(notInv).typed(types, "b") + val checkerInput = new CheckerInput(mkModuleWithX(), initTrans, nextTrans, None, List((inv, notInv))) val params = new ModelCheckerParams(checkerInput, stepsBound = 0, new File("."), Map(), false) val ctx = new IncrementalExecutionContext(rewriter) val trex = new TransitionExecutorImpl(params.consts, params.vars, ctx) @@ -71,12 +78,14 @@ class TestSeqModelChecker extends FunSuite with BeforeAndAfter { // N' <- 10 val cinit = mkAssign("N", 10) // x' <- N - val initTrans = List(mkAssign("x", tla.name("N"))) - val nextTrans = List(mkAssign("x", tla.name("N"))) - val module = new TlaModule("root", List(TlaConstDecl("N"), TlaVarDecl("x"))) - val notInv = tla.lt(tla.name("x"), tla.int(10)) - - val checkerInput = new CheckerInput(module, initTrans, nextTrans, Some(cinit), List((tla.not(notInv), notInv))) + val initTrans = List(mkAssign("x", name("N") ? "i", IntT1())) + val nextTrans = List(mkAssign("x", name("N") ? "i", IntT1())) + val module = new TlaModule("root", List(TlaConstDecl("N")(intTag), TlaVarDecl("x")(intTag))) + val notInv = lt(name("x") ? "i", int(10)) + .typed(types, "b") + val inv = not(notInv).typed(types, "b") + + val checkerInput = new CheckerInput(module, initTrans, nextTrans, Some(cinit), List((inv, notInv))) val params = new ModelCheckerParams(checkerInput, stepsBound = 0, new File("."), Map(), false) val ctx = new IncrementalExecutionContext(rewriter) val trex = new TransitionExecutorImpl(params.consts, params.vars, ctx) @@ -89,12 +98,14 @@ class TestSeqModelChecker extends FunSuite with BeforeAndAfter { // N' <- 10 val cinit = mkAssign("N", 5) // x' <- N - val initTrans = List(mkAssign("x", tla.name("N"))) - val nextTrans = List(mkAssign("x", tla.name("N"))) - val module = new TlaModule("root", List(TlaConstDecl("N"), TlaVarDecl("x"))) - val notInv = tla.lt(tla.name("x"), tla.int(10)) - - val checkerInput = new CheckerInput(module, initTrans, nextTrans, Some(cinit), List((tla.not(notInv), notInv))) + val initTrans = List(mkAssign("x", name("N") ? "i", IntT1())) + val nextTrans = List(mkAssign("x", name("N") ? "i", IntT1())) + val module = new TlaModule("root", List(TlaConstDecl("N")(intTag), TlaVarDecl("x")(intTag))) + val notInv = lt(name("x") ? "i", int(10)) + .typed(types, "b") + val inv = not(notInv).typed(types, "b") + + val checkerInput = new CheckerInput(module, initTrans, nextTrans, Some(cinit), List((inv, notInv))) val params = new ModelCheckerParams(checkerInput, stepsBound = 0, new File("."), Map(), false) val ctx = new IncrementalExecutionContext(rewriter) val trex = new TransitionExecutorImpl(params.consts, params.vars, ctx) @@ -105,7 +116,7 @@ class TestSeqModelChecker extends FunSuite with BeforeAndAfter { test("Init, deadlock") { // x' <- 2 /\ x' <- 1 - val initTrans = List(tla.and(mkAssign("x", 2), mkNotAssign("x", 1)).untyped()) + val initTrans = List(and(mkAssign("x", 2), mkNotAssign("x", 1)).typed(BoolT1())) val nextTrans = List(mkAssign("x", 2)) val checkerInput = new CheckerInput(mkModuleWithX(), initTrans, nextTrans, None, List.empty) val params = new ModelCheckerParams(checkerInput, stepsBound = 0, new File("."), Map(), false) @@ -135,7 +146,7 @@ class TestSeqModelChecker extends FunSuite with BeforeAndAfter { // x' <- 2 \/ x' <- 1 val initTrans = List(mkAssign("x", 2), mkAssign("x", 1)) // x' <- x + 1 - val nextTrans = List(mkAssign("x", tla.plus(tla.name("x"), tla.int(1)))) + val nextTrans = List(mkAssign("x", plus(name("x") ? "i", int(1)) ? "i", IntT1())) val checkerInput = new CheckerInput(mkModuleWithX(), initTrans, nextTrans, None, List.empty) val params = new ModelCheckerParams(checkerInput, stepsBound = 1, new File("."), Map(), false) // initialize the model checker @@ -150,10 +161,12 @@ class TestSeqModelChecker extends FunSuite with BeforeAndAfter { // x' <- 2 \/ x' <- 1 val initTrans = List(mkAssign("x", 2), mkAssign("x", 1)) // x' <- x + 1 - val nextTrans = List(mkAssign("x", tla.plus(tla.name("x"), tla.int(1)))) + val nextTrans = List(mkAssign("x", plus(name("x") ? "i", int(1)) ? "i", IntT1())) // x < 5 - val inv = tla.lt(tla.name("x"), tla.int(5)) - val checkerInput = new CheckerInput(mkModuleWithX(), initTrans, nextTrans, None, List((inv, tla.not(inv)))) + val inv = lt(name("x") ? "i", int(5)) + .typed(types, "b") + val notInv = not(inv).typed(types, "b") + val checkerInput = new CheckerInput(mkModuleWithX(), initTrans, nextTrans, None, List((inv, notInv))) val params = new ModelCheckerParams(checkerInput, stepsBound = 10, new File("."), Map(), false) params.discardDisabled = true params.invariantMode = InvariantMode.BeforeJoin @@ -169,10 +182,12 @@ class TestSeqModelChecker extends FunSuite with BeforeAndAfter { // x' <- 2 \/ x' <- 1 val initTrans = List(mkAssign("x", 2), mkAssign("x", 1)) // x' <- x + 1 - val nextTrans = List(mkAssign("x", tla.plus(tla.name("x"), tla.int(1)))) + val nextTrans = List(mkAssign("x", plus(name("x") ? "i", int(1)) ? "i", IntT1())) // x < 5 - val inv = tla.lt(tla.name("x"), tla.int(5)) - val checkerInput = new CheckerInput(mkModuleWithX(), initTrans, nextTrans, None, List((inv, tla.not(inv)))) + val inv = lt(name("x") ? "i", int(5)) + .typed(types, "b") + val notInv = not(inv).typed(types, "b") + val checkerInput = new CheckerInput(mkModuleWithX(), initTrans, nextTrans, None, List((inv, notInv))) val params = new ModelCheckerParams(checkerInput, stepsBound = 10, new File("."), Map(), false) params.discardDisabled = false params.invariantMode = InvariantMode.BeforeJoin @@ -188,10 +203,12 @@ class TestSeqModelChecker extends FunSuite with BeforeAndAfter { // x' <- 2 \/ x' <- 1 val initTrans = List(mkAssign("x", 2), mkAssign("x", 1)) // x' <- x + 1 - val nextTrans = List(mkAssign("x", tla.plus(tla.name("x"), tla.int(1)))) + val nextTrans = List(mkAssign("x", plus(name("x") ? "i", int(1)) ? "i", IntT1())) // x < 5 - val inv = tla.lt(tla.name("x"), tla.int(5)) - val checkerInput = new CheckerInput(mkModuleWithX(), initTrans, nextTrans, None, List((inv, tla.not(inv)))) + val inv = lt(name("x") ? "i", int(5)) + .typed(types, "b") + val notInv = not(inv).typed(types, "b") + val checkerInput = new CheckerInput(mkModuleWithX(), initTrans, nextTrans, None, List((inv, notInv))) val params = new ModelCheckerParams(checkerInput, stepsBound = 10, new File("."), Map(), false) params.discardDisabled = false params.invariantMode = InvariantMode.AfterJoin @@ -207,10 +224,12 @@ class TestSeqModelChecker extends FunSuite with BeforeAndAfter { // x' <- 2 \/ x' <- 1 val initTrans = List(mkAssign("x", 2), mkAssign("x", 1)) // x' <- x + 1 - val nextTrans = List(mkAssign("x", tla.plus(tla.name("x"), tla.int(1)))) + val nextTrans = List(mkAssign("x", plus(name("x") ? "i", int(1)) ? "i", IntT1())) // x < 5 - val inv = tla.lt(tla.name("x"), tla.int(5)) - val checkerInput = new CheckerInput(mkModuleWithX(), initTrans, nextTrans, None, List((inv, tla.not(inv)))) + val inv = lt(name("x") ? "i", int(5)) + .typed(types, "b") + val notInv = not(inv).typed(types, "b") + val checkerInput = new CheckerInput(mkModuleWithX(), initTrans, nextTrans, None, List((inv, notInv))) val params = new ModelCheckerParams(checkerInput, stepsBound = 10, new File("."), Map(), false) params.discardDisabled = true params.invariantMode = InvariantMode.AfterJoin @@ -226,13 +245,19 @@ class TestSeqModelChecker extends FunSuite with BeforeAndAfter { // x' <- 1 val initTrans = List(mkAssign("x", 1)) // x' <- x + 1 - val assign = mkAssign("x", tla.plus(tla.name("x"), tla.int(1))) + val assign = mkAssign("x", plus(name("x") ? "i", int(1)) ? "i", IntT1()) val nextTrans = List(assign) // x < 3 - val lt = tla.lt(tla.name("x"), tla.int(3)) - val letIn = tla.letIn(tla.appOp(tla.name("Foo")), tla.declOp("Foo", lt).untypedOperDecl()) - val inv = letIn - val checkerInput = new CheckerInput(mkModuleWithX(), initTrans, nextTrans, None, List((inv, tla.not(inv)))) + val pred = lt(name("x") ? "i", int(3)) + .typed(types, "b") + val letDef = letIn(appOp(name("Foo") ? "Ob") ? "b", + declOp("Foo", pred) + .typedOperDecl(types, "Ob")) + val inv = letDef + .typed(types, "b") + val notInv = not(inv).typed(types, "b") + + val checkerInput = new CheckerInput(mkModuleWithX(), initTrans, nextTrans, None, List((inv, notInv))) val params = new ModelCheckerParams(checkerInput, stepsBound = 2, new File("."), Map(), false) // initialize the model checker val ctx = new IncrementalExecutionContext(rewriter) @@ -244,21 +269,20 @@ class TestSeqModelChecker extends FunSuite with BeforeAndAfter { test("determinstic Init + 2 steps (regression)") { // y' <- 1 /\ x' <- 1 - val initTrans = List(tla.and(mkAssign("y", 1), mkAssign("x", 1)).untyped()) + val initTrans = List(and(mkAssign("y", 1), mkAssign("x", 1)).typed(BoolT1())) // y' <- y + 1 /\ x' <- x + 1 val nextTrans = List( - tla - .and( - mkAssign("y", tla.plus(tla.name("y"), tla.int(1))), - mkAssign("x", tla.plus(tla.name("x"), tla.int(1))) - ) - .untyped()) /// - val inv = tla.eql( - tla.eql(tla.int(3), tla.name("x")), - tla.eql(tla.int(3), tla.name("y")) - ) //// - - val checkerInput = new CheckerInput(mkModuleWithXandY(), initTrans, nextTrans, None, List((inv, tla.not(inv)))) + and( + mkAssign("y", plus(name("y") ? "i", int(1)) ? "i", IntT1()), + mkAssign("x", plus(name("x") ? "i", int(1)) ? "i", IntT1()) + ).typed(BoolT1())) /// + val inv = eql( + eql(int(3), name("x") ? "i") ? "b", + eql(int(3), name("y") ? "i") ? "b" + ).typed(types, "b") + val notInv = not(inv).typed(types, "b") + + val checkerInput = new CheckerInput(mkModuleWithXandY(), initTrans, nextTrans, None, List((inv, notInv))) val params = new ModelCheckerParams(checkerInput, stepsBound = 2, new File("."), Map(), false) // initialize the model checker val ctx = new IncrementalExecutionContext(rewriter) @@ -273,7 +297,8 @@ class TestSeqModelChecker extends FunSuite with BeforeAndAfter { val initTrans = List(mkAssign("x", 2), mkAssign("x", 1)) // x > 3 /\ x' <- x + 1 val nextTrans = - List(tla.and(tla.gt(tla.name("x"), tla.int(3)), mkAssign("x", tla.plus(tla.name("x"), tla.int(1)))).untyped()) + List(and(gt(name("x") ? "i", int(3)) ? "b", mkAssign("x", plus(name("x") ? "i", int(1)) ? "i", IntT1())).typed( + types, "b")) val checkerInput = new CheckerInput(mkModuleWithX(), initTrans, nextTrans, None, List()) val params = new ModelCheckerParams(checkerInput, stepsBound = 1, new File("."), Map(), false) // initialize the model checker @@ -288,7 +313,7 @@ class TestSeqModelChecker extends FunSuite with BeforeAndAfter { // x' <- 2 \/ x' <- 1 val initTrans = List(mkAssign("x", 2), mkAssign("x", 1)) // x' <- x + 1 - val nextTrans = List(mkAssign("x", tla.plus(tla.name("x"), tla.int(1)))) + val nextTrans = List(mkAssign("x", plus(name("x") ? "i", int(1)) ? "i", IntT1())) val checkerInput = new CheckerInput(mkModuleWithX(), initTrans, nextTrans, None, List()) val params = new ModelCheckerParams(checkerInput, stepsBound = 10, new File("."), Map(), false) // initialize the model checker @@ -304,7 +329,8 @@ class TestSeqModelChecker extends FunSuite with BeforeAndAfter { val initTrans = List(mkAssign("x", 2), mkAssign("x", 1)) // x < 10 /\ x' <- x + 1 val nextTrans = - List(tla.and(tla.lt(tla.name("x"), tla.int(10)), mkAssign("x", tla.plus(tla.name("x"), tla.int(1)))).untyped()) + List(and(lt(name("x") ? "i", int(10)) ? "b", mkAssign("x", plus(name("x") ? "i", int(1)) ? "i", IntT1())).typed( + types, "b")) val checkerInput = new CheckerInput(mkModuleWithX(), initTrans, nextTrans, None, List()) val params = new ModelCheckerParams(checkerInput, stepsBound = 10, new File("."), Map(), false) // initialize the model checker @@ -319,10 +345,13 @@ class TestSeqModelChecker extends FunSuite with BeforeAndAfter { // x' <- 2 \/ x' <- 1 val initTrans = List(mkAssign("x", 2), mkAssign("x", 1)) // x' <- x + 1 - val nextTrans = List(mkAssign("x", tla.plus(tla.name("x"), tla.int(1)))) + val nextTrans = List(mkAssign("x", plus(name("x") ? "i", int(1)) ? "i", IntT1())) // x < 100 - val inv = tla.lt(tla.name("x"), tla.int(100)) - val checkerInput = new CheckerInput(mkModuleWithX(), initTrans, nextTrans, None, List((inv, tla.not(inv)))) + val inv = lt(name("x") ? "i", int(100)) + .typed(types, "b") + val notInv = not(inv) + .typed(types, "b") + val checkerInput = new CheckerInput(mkModuleWithX(), initTrans, nextTrans, None, List((inv, notInv))) val params = new ModelCheckerParams(checkerInput, stepsBound = 10, new File("."), Map(), false) // initialize the model checker val ctx = new IncrementalExecutionContext(rewriter) @@ -338,10 +367,13 @@ class TestSeqModelChecker extends FunSuite with BeforeAndAfter { // x' <- 0 val initTrans = List(mkAssign("x", 0)) // x' <- x + 1 - val nextTrans = List(mkAssign("x", tla.plus(tla.name("x"), tla.int(1)))) + val nextTrans = List(mkAssign("x", plus(name("x") ? "i", int(1)) ? "i", IntT1())) // x /= 3 - val inv = tla.not(tla.eql(tla.name("x"), tla.int(3))) - val checkerInput = new CheckerInput(mkModuleWithX(), initTrans, nextTrans, None, List((inv, tla.not(inv)))) + val notInv = eql(name("x") ? "i", int(3)) + .typed(types, "b") + val inv = not(notInv) + .typed(BoolT1()) + val checkerInput = new CheckerInput(mkModuleWithX(), initTrans, nextTrans, None, List((inv, notInv))) val params = new ModelCheckerParams(checkerInput, stepsBound = 3, new File("."), Map(), false) // initialize the model checker val ctx = new IncrementalExecutionContext(rewriter) @@ -356,10 +388,13 @@ class TestSeqModelChecker extends FunSuite with BeforeAndAfter { // x' <- 0 val initTrans = List(mkAssign("x", 0)) // x' <- x + 1 - val nextTrans = List(mkAssign("x", tla.plus(tla.name("x"), tla.int(1)))) + val nextTrans = List(mkAssign("x", plus(name("x") ? "i", int(1)) ? "i", IntT1())) // x /= 3 - val inv = tla.not(tla.eql(tla.name("x"), tla.int(3))) - val checkerInput = new CheckerInput(mkModuleWithX(), initTrans, nextTrans, None, List((inv, tla.not(inv)))) + val notInv = eql(name("x") ? "i", int(3)) + .typed(types, "b") + val inv = not(notInv) + .typed(BoolT1()) + val checkerInput = new CheckerInput(mkModuleWithX(), initTrans, nextTrans, None, List((inv, notInv))) val params = new ModelCheckerParams(checkerInput, stepsBound = 2, new File("."), Map(), false) // initialize the model checker val ctx = new IncrementalExecutionContext(rewriter) @@ -373,10 +408,13 @@ class TestSeqModelChecker extends FunSuite with BeforeAndAfter { // x' <- 2 \/ x' <- 1 val initTrans = List(mkAssign("x", 2), mkAssign("x", 1)) // x' <- x + 1 - val nextTrans = List(mkAssign("x", tla.plus(tla.name("x"), tla.int(1)))) + val nextTrans = List(mkAssign("x", plus(name("x") ? "i", int(1)) ? "i", IntT1())) // x < 5 - val inv = tla.lt(tla.name("x"), tla.int(5)) - val checkerInput = new CheckerInput(mkModuleWithX(), initTrans, nextTrans, None, List((inv, tla.not(inv)))) + val inv = lt(name("x") ? "i", int(5)) + .typed(types, "b") + val notInv = not(inv) + .typed(types, "b") + val checkerInput = new CheckerInput(mkModuleWithX(), initTrans, nextTrans, None, List((inv, notInv))) // initialize the model checker // We require the invariant to be checked only after the second step. So we will miss invariant violation. val tuning = Map("search.invariantFilter" -> "2") @@ -393,8 +431,9 @@ class TestSeqModelChecker extends FunSuite with BeforeAndAfter { // x' <- 1 val initTrans = List(mkAssign("x", 1)) // x' <- x + 1 \/ x > 100 /\ x' <- x - val trans1 = mkAssign("x", tla.plus(tla.name("x"), tla.int(1))) - val trans2 = tla.and(tla.gt(tla.name("x"), tla.int(100)), mkAssign("x", tla.name("x"))).untyped() + val trans1 = mkAssign("x", plus(name("x") ? "i", int(1)) ? "i", IntT1()) + val trans2 = and(gt(name("x") ? "i", int(100)) ? "b", mkAssign("x", name("x") ? "i", IntT1())) + .typed(types, "b") val nextTrans = List(trans1, trans2) val checkerInput = new CheckerInput(mkModuleWithX(), initTrans, nextTrans, None, List()) val params = new ModelCheckerParams(checkerInput, stepsBound = 3, new File("."), Map(), false) @@ -410,11 +449,16 @@ class TestSeqModelChecker extends FunSuite with BeforeAndAfter { // x' <- 0 \/ x' <- 1 val initTrans = List(mkAssign("x", 0), mkAssign("x", 1)) // x' <- x + 1 \/ x > 10 /\ x' <- x - val trans1 = mkAssign("x", tla.plus(tla.name("x"), tla.int(1))) - val trans2 = tla.and(tla.gt(tla.name("x"), tla.int(10)), mkAssign("x", tla.name("x"))).untyped() + val trans1 = mkAssign("x", plus(name("x") ? "i", int(1)) ? "i", IntT1()) + val trans2 = and(gt(name("x") ? "i", int(10)) ? "b", mkAssign("x", name("x") ? "i", IntT1())) + .typed(types, "b") val nextTrans = List(trans1, trans2) - val notInv = tla.gt(tla.name("x"), tla.int(10)) // ~(x <= 10) - val checkerInput = new CheckerInput(mkModuleWithX(), initTrans, nextTrans, None, List((tla.not(notInv), notInv))) + // ~(x <= 10) + val notInv = gt(name("x") ? "i", int(10)) + .typed(types, "b") + val inv = not(notInv) + .typed(types, "b") + val checkerInput = new CheckerInput(mkModuleWithX(), initTrans, nextTrans, None, List((inv, notInv))) val params = new ModelCheckerParams(checkerInput, stepsBound = 10, new File("."), Map(), false) // initialize the model checker val ctx = new IncrementalExecutionContext(rewriter) @@ -428,21 +472,27 @@ class TestSeqModelChecker extends FunSuite with BeforeAndAfter { // x' <- 0 \/ x' <- 1 val initTrans = List(mkAssign("x", 0), mkAssign("x", 1)) // x' <- x + 1 \/ x > 10 /\ x' <- x - val trans1 = mkAssign("x", tla.plus(tla.name("x"), tla.int(1))) - val trans2 = tla.and(tla.gt(tla.name("x"), tla.int(10)), mkAssign("x", tla.name("x"))).untyped() + val trans1 = mkAssign("x", plus(name("x") ? "i", int(1)) ? "i", IntT1()) + val trans2 = and(gt(name("x") ? "i", int(10)) ? "b", mkAssign("x", name("x") ? "i", IntT1())) + .typed(types, "b") val nextTrans = List(trans1, trans2) // a constant initializer: \E t \in { 20, 10 }: N' \in {t} val cInit = - OperEx(BmcOper.skolem, - tla.exists( - tla.name("t"), - tla.enumSet(tla.int(20), tla.int(10)), - mkAssign("N", tla.name("t")) - )) //// - - val notInv = tla.gt(tla.name("x"), tla.name("N")) // ~(x <= N) - val dummyModule = new TlaModule("root", List(TlaConstDecl("N"), TlaVarDecl("x"))) - val checkerInput = new CheckerInput(dummyModule, initTrans, nextTrans, Some(cInit), List((tla.not(notInv), notInv))) + apalacheSkolem( + exists( + name("t") ? "i", + enumSet(int(20), int(10)) ? "I", + mkAssign("N", name("t") ? "i", IntT1()) + ) ? "b") + .typed(types, "b") + + // ~(x <= N) + val notInv = gt(name("x") ? "i", name("N") ? "i") + .typed(types, "b") + val inv = not(notInv) + .typed(types, "b") + val dummyModule = new TlaModule("root", List(TlaConstDecl("N")(intTag), TlaVarDecl("x")(intTag))) + val checkerInput = new CheckerInput(dummyModule, initTrans, nextTrans, Some(cInit), List((inv, notInv))) val params = new ModelCheckerParams(checkerInput, stepsBound = 10, new File("."), Map(), false) // initialize the model checker val ctx = new IncrementalExecutionContext(rewriter) @@ -456,11 +506,15 @@ class TestSeqModelChecker extends FunSuite with BeforeAndAfter { // x' <- 0 val initTrans = List(mkAssign("x", 0)) // x' <- x + 1 \/ x' <- x + 2 - val trans1 = mkAssign("x", tla.plus(tla.name("x"), tla.int(1))) - val trans2 = mkAssign("x", tla.plus(tla.name("x"), tla.int(2))) + val trans1 = mkAssign("x", plus(name("x") ? "i", int(1)) ? "i", IntT1()) + val trans2 = mkAssign("x", plus(name("x") ? "i", int(2)) ? "i", IntT1()) val nextTrans = List(trans1, trans2) - val notInv = tla.gt(tla.name("x"), tla.int(11)) // ~(x <= 11) - val checkerInput = new CheckerInput(mkModuleWithX(), initTrans, nextTrans, None, List((tla.not(notInv), notInv))) + // ~(x <= 11) + val notInv = gt(name("x") ? "i", int(11)) + .typed(types, "b") + val inv = not(notInv) + .typed(types, "b") + val checkerInput = new CheckerInput(mkModuleWithX(), initTrans, nextTrans, None, List((inv, notInv))) // initialize the model checker val filter = "0,0,0,0,0,0,0,0,0,0,0" // old syntax val tuning = Map.empty[String, String] // Map("search.transitionFilter" -> filter) @@ -473,15 +527,23 @@ class TestSeqModelChecker extends FunSuite with BeforeAndAfter { assert(Checker.Outcome.NoError == outcome) } - private def mkAssign(name: String, value: Int): TlaEx = - tla.assignPrime(tla.name(name), tla.int(value)) + private def mkAssign(varName: String, value: Int): TlaEx = { + assign(prime(name(varName) ? "i") ? "i", int(value)) + .typed(types, "b") + } - private def mkAssign(name: String, rhs: TlaEx): TlaEx = - tla.assignPrime(tla.name(name), rhs) + private def mkAssign(varName: String, rhs: BuilderEx, tt: TlaType1): TlaEx = { + assign(prime(name(varName) ? "_tt") ? "_tt", rhs) + .typed(types + ("_tt" -> tt), "b") + } - private def mkNotAssign(name: String, value: Int): TlaEx = - tla.primeEq(tla.name(name), tla.int(value)) + private def mkNotAssign(varName: String, value: Int): TlaEx = { + eql(prime(name(varName) ? "i") ? "i", int(value) ? "i") + .typed(types, "b") + } - private def mkNotAssign(name: String, rhs: TlaEx): TlaEx = - tla.primeEq(tla.name(name), rhs) + private def mkNotAssign(varName: String, rhs: BuilderEx, tt: TlaType1): TlaEx = { + eql(prime(name(varName) ? "_tt") ? "_tt", rhs) + .typed(types + ("_tt" -> tt), "b") + } } diff --git a/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestSymbStateDecoder.scala b/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestSymbStateDecoder.scala index 884ab96966..0025819a05 100644 --- a/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestSymbStateDecoder.scala +++ b/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestSymbStateDecoder.scala @@ -1,17 +1,27 @@ package at.forsyte.apalache.tla.bmcmt -import at.forsyte.apalache.tla.bmcmt.types.{AnnotationParser, FinSetT, IntT, SeqT} -import at.forsyte.apalache.tla.lir.{TlaEx, ValEx} -import at.forsyte.apalache.tla.lir.convenience.tla -import at.forsyte.apalache.tla.lir.values.{TlaIntSet, TlaNatSet} -import at.forsyte.apalache.tla.lir.UntypedPredefs._ +import at.forsyte.apalache.tla.lir.TypedPredefs._ +import at.forsyte.apalache.tla.lir.convenience.tla._ +import at.forsyte.apalache.tla.lir._ import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner @RunWith(classOf[JUnitRunner]) class TestSymbStateDecoder extends RewriterBase { + private val types = Map( + "i" -> IntT1(), + "I" -> SetT1(IntT1()), + "II" -> SetT1(SetT1(IntT1())), + "Qi" -> SeqT1(IntT1()), + "iii" -> TupT1(IntT1(), IntT1(), IntT1()), + "rib" -> RecT1("a" -> IntT1(), "b" -> BoolT1()), + "b" -> BoolT1(), + "i_to_i" -> FunT1(IntT1(), IntT1()), + "i_TO_i" -> SetT1(FunT1(IntT1(), IntT1())) + ) + test("decode bool") { - val originalEx: TlaEx = tla.bool(true) + val originalEx: TlaEx = bool(true).typed() val state = new SymbState(originalEx, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) @@ -21,11 +31,13 @@ class TestSymbStateDecoder extends RewriterBase { val decodedEx = decoder.decodeCellToTlaEx(nextState.arena, cell) assert(originalEx == decodedEx) // hard core comparison - assertTlaExAndRestore(rewriter, nextState.setRex(tla.eql(decodedEx, originalEx))) + val eq = eql(decodedEx ? "b", originalEx ? "b") + .typed(types, "b") + assertTlaExAndRestore(rewriter, nextState.setRex(eq)) } test("decode int") { - val originalEx: TlaEx = tla.int(3) + val originalEx = int(3).typed() val state = new SymbState(originalEx, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) @@ -35,11 +47,13 @@ class TestSymbStateDecoder extends RewriterBase { val decodedEx = decoder.decodeCellToTlaEx(nextState.arena, cell) assert(originalEx == decodedEx) // hard core comparison - assertTlaExAndRestore(rewriter, nextState.setRex(tla.eql(decodedEx, originalEx))) + val eq = eql(decodedEx ? "b", originalEx ? "b") + .typed(types, "b") + assertTlaExAndRestore(rewriter, nextState.setRex(eq)) } test("decode str") { - val originalEx: TlaEx = tla.str("hello") + val originalEx: TlaEx = str("hello").typed() val state = new SymbState(originalEx, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) @@ -49,11 +63,14 @@ class TestSymbStateDecoder extends RewriterBase { val decodedEx = decoder.decodeCellToTlaEx(nextState.arena, cell) assert(originalEx == decodedEx) // hard core comparison - assertTlaExAndRestore(rewriter, nextState.setRex(tla.eql(decodedEx, originalEx))) + val eq = eql(decodedEx, originalEx) + .typed(types, "b") + assertTlaExAndRestore(rewriter, nextState.setRex(eq)) } test("decode Int set") { - val originalEx = ValEx(TlaIntSet) + val originalEx = intSet() + .typed(types, "I") val state = new SymbState(originalEx, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) @@ -65,7 +82,8 @@ class TestSymbStateDecoder extends RewriterBase { } test("decode Nat set") { - val originalEx = ValEx(TlaNatSet) + val originalEx = natSet() + .typed(types, "I") val state = new SymbState(originalEx, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) @@ -77,8 +95,10 @@ class TestSymbStateDecoder extends RewriterBase { } test("decode set") { - val originalEx = tla.enumSet(tla.int(2), tla.int(1), tla.int(2)) - val simpleOriginalEx: TlaEx = tla.enumSet(tla.int(1), tla.int(2)) + val originalEx = enumSet(int(2), int(1), int(2)) + .typed(types, "I") + val simpleOriginalEx: TlaEx = enumSet(int(1), int(2)) + .typed(types, "I") val state = new SymbState(originalEx, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) @@ -88,13 +108,18 @@ class TestSymbStateDecoder extends RewriterBase { val decodedEx = decoder.decodeCellToTlaEx(nextState.arena, cell) assert(simpleOriginalEx == decodedEx) // hard core comparison - assertTlaExAndRestore(rewriter, nextState.setRex(tla.eql(decodedEx, simpleOriginalEx))) + val eq = eql(decodedEx ? "I", simpleOriginalEx ? "I") + .typed(types, "b") + assertTlaExAndRestore(rewriter, nextState.setRex(eq)) } test("decode fun set") { - val domEx = tla.enumSet(tla.int(1), tla.int(2)) - val cdmEx = tla.enumSet(tla.int(3), tla.int(4)) - val originalEx: TlaEx = tla.funSet(domEx, cdmEx) + val domEx = enumSet(int(1), int(2)) + .typed(types, "I") + val cdmEx = enumSet(int(3), int(4)) + .typed(types, "I") + val originalEx = funSet(domEx, cdmEx) + .typed(types, "i_TO_i") val state = new SymbState(originalEx, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) @@ -104,12 +129,16 @@ class TestSymbStateDecoder extends RewriterBase { val decodedEx = decoder.decodeCellToTlaEx(nextState.arena, cell) assert(originalEx == decodedEx) // hard core comparison - assertTlaExAndRestore(rewriter, nextState.setRex(tla.eql(decodedEx, originalEx))) + val eq = eql(decodedEx, originalEx) + .typed(types, "b") + assertTlaExAndRestore(rewriter, nextState.setRex(eq)) } test("decode SUBSET S") { - val set = tla.enumSet(tla.int(1), tla.int(2)) - val powset: TlaEx = tla.powSet(set) + val set = enumSet(int(1), int(2)) + .typed(types, "I") + val powset = powSet(set) + .typed(types, "II") val state = new SymbState(powset, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) @@ -121,8 +150,10 @@ class TestSymbStateDecoder extends RewriterBase { } test("decode fun") { - val domEx = tla.enumSet(tla.int(1), tla.int(2)) - val funEx = tla.funDef(tla.plus(tla.name("x"), tla.int(1)), tla.name("x"), domEx) + val domEx = enumSet(int(1), int(2)) + .typed(types, "I") + val funEx = funDef(plus(name("x") ? "i", int(1)) ? "i", name("x") ? "i", domEx) + .typed(types, "i_to_i") val state = new SymbState(funEx, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) @@ -131,15 +162,18 @@ class TestSymbStateDecoder extends RewriterBase { val decoder = new SymbStateDecoder(solverContext, rewriter) val decodedEx = decoder.decodeCellToTlaEx(nextState.arena, cell) val expectedOutcome: TlaEx = - tla.atat(tla.int(1), tla.int(2), tla.int(2), tla.int(3)) + atat(colonGreater(int(1), int(2)) ? "i_to_i", colonGreater(int(2), int(3)) ? "i_to_i") + .typed(types, "i_to_i") assert(expectedOutcome == decodedEx) - // we cannot directly compare the outcome, as it comes in the same form as a record - // assertTlaExAndRestore(rewriter, nextState.setRex(tla.eql(decodedEx, funEx))) + val eq = eql(decodedEx, funEx) + .typed(BoolT1()) + assertTlaExAndRestore(rewriter, nextState.setRex(eq)) } test("decode statically empty fun") { - val domEx = tla.withType(tla.enumSet(), AnnotationParser.toTla(FinSetT(IntT()))) - val funEx = tla.funDef(tla.plus(tla.name("x"), tla.int(1)), tla.name("x"), domEx) + val domEx = enumSet() ? "I" + val funEx = funDef(plus(name("x") ? "i", int(1)) ? "i", name("x") ? "i", domEx) + .typed(types, "i_to_i") val state = new SymbState(funEx, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) @@ -148,20 +182,23 @@ class TestSymbStateDecoder extends RewriterBase { val decoder = new SymbStateDecoder(solverContext, rewriter) val decodedEx = decoder.decodeCellToTlaEx(nextState.arena, cell) // this is the standard outcome for an empty-domain function: {x \in {} |-> {}} - val expectedOutcome: TlaEx = tla.atat() + val expectedOutcome = funDef(name("x") ? "i", name("x") ? "i", enumSet() ? "I") + .typed(types, "i_to_i") assert(expectedOutcome == decodedEx) - // we cannot directly compare the outcome, as it comes in the same form as a record - // assertTlaExAndRestore(rewriter, nextState.setRex(tla.eql(decodedEx, funEx))) + assertTlaExAndRestore(rewriter, nextState.setRex(eql(decodedEx, funEx).typed(BoolT1()))) } test("decode dynamically empty fun") { // this domain is not empty at the arena level, but it is in every SMT model - def dynEmpty(left: TlaEx): TlaEx = { - tla.filter(tla.name("t"), left, tla.bool(false)) + def dynEmpty(left: BuilderEx): BuilderEx = { + filter(name("t") ? "i", left, bool(false) ? "b") } - val domEx = dynEmpty(tla.enumSet(tla.int(1))) - val funEx = tla.funDef(tla.plus(tla.name("x"), tla.int(1)), tla.name("x"), domEx) + val domEx = dynEmpty(enumSet(int(1)) ? "I") + .typed(types, "I") + // [ x \in {} |-> x + 1 ] + val funEx = funDef(plus(name("x") ? "i", int(1)) ? "i", name("x") ? "i", domEx) + .typed(types, "i_to_i") val state = new SymbState(funEx, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) @@ -169,17 +206,20 @@ class TestSymbStateDecoder extends RewriterBase { val cell = nextState.asCell val decoder = new SymbStateDecoder(solverContext, rewriter) val decodedEx = decoder.decodeCellToTlaEx(nextState.arena, cell) - // this is the standard outcome for an empty-domain function: {x \in {} |-> {}} - val expectedOutcome: TlaEx = tla.atat() + // this is the standard outcome for an empty-domain function: {x \in {} |-> x} + val expectedOutcome = funDef(name("x") ? "i", name("x") ? "i", enumSet() ? "I") + .typed(types, "i_to_i") assert(expectedOutcome == decodedEx) // we cannot directly compare the outcome, as it comes in the same form as a record - // assertTlaExAndRestore(rewriter, nextState.setRex(tla.eql(decodedEx, funEx))) + assertTlaExAndRestore(rewriter, nextState.setRex(eql(decodedEx, funEx).typed(BoolT1()))) } test("decode sequence") { val seqEx = - tla.withType(tla.tuple(tla.int(1), tla.int(2), tla.int(3), tla.int(4)), AnnotationParser.toTla(SeqT(IntT()))) - val subseqEx = tla.subseq(seqEx, tla.int(2), tla.int(3)) + tuple(int(1), int(2), int(3), int(4)) + .typed(types, "Qi") + val subseqEx = subseq(seqEx, int(2), int(3)) + .typed(types, "Qi") val state = new SymbState(subseqEx, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) @@ -187,13 +227,15 @@ class TestSymbStateDecoder extends RewriterBase { val cell = nextState.asCell val decoder = new SymbStateDecoder(solverContext, rewriter) val decodedEx = decoder.decodeCellToTlaEx(nextState.arena, cell) - val expected: TlaEx = tla.tuple(tla.int(2), tla.int(3)) + val expected = tuple(int(2), int(3)) + .typed(types, "Qi") assert(expected == decodedEx) } test("decode tuple") { val tupleEx: TlaEx = - tla.tuple(tla.int(1), tla.int(2), tla.int(3)) + tuple(int(1), int(2), int(3)) + .typed(types, "iii") val state = new SymbState(tupleEx, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) @@ -205,8 +247,9 @@ class TestSymbStateDecoder extends RewriterBase { } test("decode record") { - val recEx: TlaEx = - tla.enumFun(tla.str("a"), tla.int(1), tla.str("b"), tla.bool(true)) + val recEx = + enumFun(str("a"), int(1), str("b"), bool(true)) + .typed(types, "rib") val state = new SymbState(recEx, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) diff --git a/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestSymbStateRewriterAction.scala b/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestSymbStateRewriterAction.scala index 6518765dd5..6dd3c8aec9 100644 --- a/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestSymbStateRewriterAction.scala +++ b/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestSymbStateRewriterAction.scala @@ -10,7 +10,7 @@ import org.scalatest.junit.JUnitRunner @RunWith(classOf[JUnitRunner]) class TestSymbStateRewriterAction extends RewriterBase { - test("""SE-PRIME: x' ~~> NameEx(x')""") { + test("""x' is rewritten to the binding of x'""") { val rewriter = create() arena.appendCell(IntT()) // the type finder is strict about unassigned types, so let's create a cell for x' val state = new SymbState(tla.prime(NameEx("x")), arena, Binding("x'" -> arena.topCell)) diff --git a/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestSymbStateRewriterAssignment.scala b/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestSymbStateRewriterAssignment.scala index a6944722ae..3987661485 100644 --- a/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestSymbStateRewriterAssignment.scala +++ b/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestSymbStateRewriterAssignment.scala @@ -1,33 +1,33 @@ package at.forsyte.apalache.tla.bmcmt import at.forsyte.apalache.tla.bmcmt.types._ -import at.forsyte.apalache.tla.lir.convenience.tla -import at.forsyte.apalache.tla.lir.values.{TlaIntSet, TlaNatSet} +import at.forsyte.apalache.tla.lir.TypedPredefs._ import at.forsyte.apalache.tla.lir._ -import at.forsyte.apalache.tla.lir.oper.BmcOper -import at.forsyte.apalache.tla.lir.UntypedPredefs._ +import at.forsyte.apalache.tla.lir.convenience.tla._ import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner -import scala.collection.immutable.{SortedMap, SortedSet, TreeMap} - /** * Tests for assignments. The assignments were at the core of Apalache 0.5.x. In Apalache 0.6.x, they are preprocessed * into Skolemizable existential quantifiers. We keep the tests for regression. */ @RunWith(classOf[JUnitRunner]) -class TestSymbStateRewriterAssignment extends RewriterBase with TestingPredefs { - private val set12: TlaEx = tla.enumSet(tla.int(1), tla.int(2)) - private val x_prime: TlaEx = tla.prime(tla.name("x")) - private val y_prime: TlaEx = tla.prime(tla.name("y")) - private val boundName: TlaEx = tla.name("t") - private val boolset: TlaEx = tla.enumSet(tla.bool(false), tla.bool(true)) - - test("""SE-IN-ASSIGN1(int): \E t \in {1, 2}: x' \in {t} ~~> TRUE and [x -> $C$k]""") { - val set = set12 - val assign = OperEx(BmcOper.skolem, tla.exists(boundName, set, tla.assign(x_prime, boundName))) - - val state = new SymbState(assign, arena, Binding()) +class TestSymbStateRewriterAssignment extends RewriterBase { + private val types = + Map("b" -> BoolT1(), "B" -> SetT1(BoolT1()), "i" -> IntT1(), "I" -> SetT1(IntT1()), "II" -> SetT1(SetT1(IntT1())), + "b_to_i" -> FunT1(BoolT1(), IntT1()), "b_TO_i" -> SetT1(FunT1(BoolT1(), IntT1())), + "i_to_b" -> FunT1(IntT1(), BoolT1()), "i_to_i" -> FunT1(IntT1(), IntT1()), + "i_TO_i" -> SetT1(FunT1(IntT1(), IntT1())), "ibI" -> TupT1(IntT1(), BoolT1(), SetT1(IntT1()))) + private val set12: TlaEx = enumSet(int(1), int(2)).typed(SetT1(IntT1())) + private val x_prime: TlaEx = prime(name("x") ? "i").typed(types, "i") + private val y_prime: TlaEx = prime(name("y") ? "i").typed(types, "i") + private val boundName: TlaEx = name("t").typed(IntT1()) + private val boolset: TlaEx = enumSet(bool(false), bool(true)).typed(SetT1(BoolT1())) + + test("""\E t \in {1, 2}: x' \in {t} ~~> TRUE and [x -> $C$k]""") { + val asgn = apalacheSkolem(exists(boundName, set12, assign(x_prime, boundName) ? "b") ? "b").typed(types, "b") + + val state = new SymbState(asgn, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) @@ -38,26 +38,26 @@ class TestSymbStateRewriterAssignment extends RewriterBase with TestingPredefs { assert(nextState.binding.contains("x'")) val boundCell = nextState.binding("x'") rewriter.push() - solverContext.assertGroundExpr(tla.eql(boundCell.toNameEx, tla.int(1))) + solverContext.assertGroundExpr(eql(boundCell.toNameEx ? "i", int(1)).typed(types, "b")) assert(solverContext.sat()) // ok rewriter.pop() rewriter.push() - solverContext.assertGroundExpr(tla.eql(boundCell.toNameEx, tla.int(2))) + solverContext.assertGroundExpr(eql(boundCell.toNameEx ? "i", int(2)).typed(types, "b")) assert(solverContext.sat()) // also possible rewriter.pop() rewriter.push() - solverContext.assertGroundExpr(tla.eql(boundCell.toNameEx, tla.int(3))) + solverContext.assertGroundExpr(eql(boundCell.toNameEx ? "i", int(3)).typed(types, "b")) assertUnsatOrExplain(rewriter, nextState) // should not be possible } - test("""SE-IN-ASSIGN1(int): assign in conjunction""") { - val and = - tla.and( - tla.assign(x_prime, tla.int(1)), - tla.assign(y_prime, tla.int(2)) - ) + test("""assign in conjunction""") { + val and1 = + and( + assign(x_prime, int(1)) ? "b", + assign(y_prime, int(2)) ? "b" + ).typed(types, "b") - val state = new SymbState(and, arena, Binding()) + val state = new SymbState(and1, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) val x_cell = nextState.binding("x'").toNameEx @@ -65,23 +65,24 @@ class TestSymbStateRewriterAssignment extends RewriterBase with TestingPredefs { assert(solverContext.sat()) // no contradiction introduced rewriter.push() - solverContext.assertGroundExpr(tla.eql(x_cell, tla.int(1))) - solverContext.assertGroundExpr(tla.eql(y_cell, tla.int(2))) + solverContext.assertGroundExpr(eql(x_cell, int(1)).typed(types, "b")) + solverContext.assertGroundExpr(eql(y_cell, int(2)).typed(types, "b")) assert(solverContext.sat()) // ok rewriter.pop() rewriter.push() - solverContext.assertGroundExpr(tla.neql(x_cell, tla.int(1))) + solverContext.assertGroundExpr(neql(x_cell, int(1)).typed(types, "b")) assertUnsatOrExplain(rewriter, nextState) // should not be possible rewriter.pop() rewriter.push() - solverContext.assertGroundExpr(tla.neql(y_cell, tla.int(2))) + solverContext.assertGroundExpr(neql(y_cell, int(2)).typed(types, "b")) assertUnsatOrExplain(rewriter, nextState) // should not be possible } - test("""SE-IN-ASSIGN1(int): x' \in {} ~~> FALSE""") { - val assign = OperEx(BmcOper.skolem, tla.exists(boundName, tla.enumSet(), tla.assign(x_prime, boundName))) + test("""\E t \in {}: x' = t ~~> FALSE""") { + val asgn = + apalacheSkolem(exists(boundName, enumSet() ? "I", assign(x_prime, boundName) ? "b") ? "b").typed(types, "b") - val state = new SymbState(assign, arena, Binding()) + val state = new SymbState(asgn, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) nextState.ex match { @@ -94,41 +95,41 @@ class TestSymbStateRewriterAssignment extends RewriterBase with TestingPredefs { } } - test("""SE-IN-ASSIGN1(int): \E t \in \in {t_2 \in {1}: FALSE}: x' \in {t} ~~> FALSE""") { + test("""\E t \in \in {t_2 \in {1}: FALSE}: x' \in {t} ~~> FALSE""") { // a regression test - def empty(set: TlaEx): TlaEx = { - tla.filter(tla.name("t_2"), set, tla.bool(false)) + def empty(set: BuilderEx): TlaEx = { + filter(name("t_2") ? "i", set ? "I", bool(false)).typed(types, "I") } - val assign = - OperEx(BmcOper.skolem, tla.exists(boundName, empty(tla.enumSet(tla.int(1))), tla.assign(x_prime, boundName))) + val asgn = + apalacheSkolem(exists(boundName, empty(enumSet(int(1))), assign(x_prime, boundName) ? "b") ? "b").typed(types, + "b") - val state = new SymbState(assign, arena, Binding()) + val state = new SymbState(asgn, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) // no contradiction should be introduced assert(solverContext.sat()) // the assignment gives us false - assertTlaExAndRestore(rewriter, nextState.setRex(tla.not(nextState.ex))) + assertTlaExAndRestore(rewriter, nextState.setRex(not(nextState.ex).typed(BoolT1()))) } - test("""SE-IN-ASSIGN1(int): x' \in {1} /\ x' = 1 ~~> TRUE and [x -> $C$k]""") { - val assign = tla.assign(x_prime, tla.int(1)) - val and = tla.and(assign, tla.eql(x_prime, tla.int(1))) + test("""x' \in {1} /\ x' = 1 ~~> TRUE and [x -> $C$k]""") { + val asgn = assign(x_prime, int(1)) + val and1 = and(asgn ? "b", eql(x_prime, int(1)) ? "b").typed(types, "b") - val state = new SymbState(and, arena, Binding()) + val state = new SymbState(and1, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) - val boundCell = - nextState.ex match { - case NameEx(name) => - assert(nextState.binding.toMap.size == 1) - assert(nextState.binding.contains("x'")) - nextState.binding("x'") + nextState.ex match { + case NameEx(_) => + assert(nextState.binding.toMap.size == 1) + assert(nextState.binding.contains("x'")) + nextState.binding("x'") - case _ => - fail("Unexpected rewriting result") - } + case _ => + fail("Unexpected rewriting result") + } assert(solverContext.sat()) // no contradiction introduced rewriter.push() @@ -136,15 +137,15 @@ class TestSymbStateRewriterAssignment extends RewriterBase with TestingPredefs { assert(solverContext.sat()) // ok rewriter.pop() rewriter.push() - solverContext.assertGroundExpr(tla.not(nextState.ex)) + solverContext.assertGroundExpr(not(nextState.ex).typed(BoolT1())) assert(!solverContext.sat()) } - test("""SE-IN-ASSIGN1(set): \E t \in {{1, 2}, {2, 3}}: x' \in {t} ~~> TRUE and [x -> $C$k]""") { - val set = tla.enumSet(set12, tla.enumSet(tla.int(2), tla.int(3))) - val assign = OperEx(BmcOper.skolem, tla.exists(boundName, set, tla.assign(x_prime, boundName))) + test("""\E t \in {{1, 2}, {2, 3}}: x' \in {t} ~~> TRUE and [x -> $C$k]""") { + val set = enumSet(set12, enumSet(int(2), int(3)) ? "I").typed(types, "II") + val asgn = apalacheSkolem(exists(boundName, set, assign(x_prime, boundName) ? "b") ? "b").typed(types, "b") - val state = new SymbState(assign, arena, Binding()) + val state = new SymbState(asgn, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) // no contradiction introduced @@ -158,42 +159,44 @@ class TestSymbStateRewriterAssignment extends RewriterBase with TestingPredefs { // may equal to {1, 2} rewriter.push() - val eq12 = tla.eql(boundCell.toNameEx, set12) + val eq12 = eql(boundCell.toNameEx ? "I", set12).typed(types, "b") val eqState12 = rewriter.rewriteUntilDone(nextState.setRex(eq12)) solverContext.assertGroundExpr(eqState12.ex) assert(solverContext.sat()) // ok rewriter.pop() // may equal to {2, 3} rewriter.push() - val eq23 = tla.eql(boundCell.toNameEx, tla.enumSet(tla.int(2), tla.int(3))) + val eq23 = eql(boundCell.toNameEx ? "I", enumSet(int(2), int(3)) ? "I").typed(types, "b") val eqState23 = rewriter.rewriteUntilDone(nextState.setRex(eq23)) solverContext.assertGroundExpr(eqState23.ex) assert(solverContext.sat()) // also possible rewriter.pop() // not equal to {1, 3} rewriter.push() - val eq13 = tla.eql(boundCell.toNameEx, tla.enumSet(tla.int(1), tla.int(3))) + val eq13 = eql(boundCell.toNameEx ? "I", enumSet(int(1), int(3)) ? "I").typed(types, "b") val eqState13 = rewriter.rewriteUntilDone(nextState.setRex(eq13)) solverContext.assertGroundExpr(eqState13.ex) assertUnsatOrExplain(rewriter, eqState13) // should not be possible } - test("""SE-IN-ASSIGN1(set): \E t \in {{1, 2}, {1+1, 2, 3}} \ {{2, 3}}: x' \in {t} ~~> TRUE and [x -> $C$k]""") { + test("""\E t \in {{1, 2}, {1+1, 2, 3}} \ {{2, 3}}: x' \in {t} ~~> TRUE and [x -> $C$k]""") { // equal elements in different sets mess up picking from a set def setminus(left: TlaEx, right: TlaEx): TlaEx = { // this is how Keramelizer translates setminus - tla.filter(tla.name("t_2"), left, tla.not(tla.eql(tla.name("t_2"), right))) + filter(name("t_2") ? "I", left ? "II", not(eql(name("t_2") ? "I", right ? "I") ? "b") ? "b") + .typed(types, "II") } - val set1to3 = tla.enumSet(tla.plus(tla.int(1), tla.int(1)), tla.int(2), tla.int(3)) - val set12 = tla.enumSet(tla.int(1), tla.int(2)) - val twoSets = tla.enumSet(set12, set1to3) - val set23 = tla.enumSet(tla.int(2), tla.int(3)) + val set1to3 = enumSet(plus(int(1), int(1)) ? "i", int(2), int(3)).typed(types, "I") + val set12 = enumSet(int(1), int(2)).typed(types, "I") + val twoSets = enumSet(set12, set1to3).typed(types, "II") + val set23 = enumSet(int(2), int(3)).typed(types, "I") val minus = setminus(twoSets, set23) - val assign = - OperEx(BmcOper.skolem, tla.exists(boundName, minus, tla.assign(x_prime, boundName))) + val asgn = + apalacheSkolem(exists(boundName, minus, assign(x_prime, boundName) ? "b") ? "b") + .typed(types, "b") - val state = new SymbState(assign, arena, Binding()) + val state = new SymbState(asgn, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) // no contradiction introduced @@ -206,40 +209,45 @@ class TestSymbStateRewriterAssignment extends RewriterBase with TestingPredefs { // may equal to {1, 2} rewriter.push() - val eq12 = tla.eql(boundCell.toNameEx, tla.enumSet(tla.int(1), tla.int(2))) + val eq12 = eql(boundCell.toNameEx ? "I", enumSet(int(1), int(2)) ? "I") + .typed(types, "b") val eqState12 = rewriter.rewriteUntilDone(nextState.setRex(eq12)) solverContext.assertGroundExpr(eqState12.ex) assert(solverContext.sat()) // ok rewriter.pop() // not equal to {1, 3} rewriter.push() - val eq13 = tla.eql(boundCell.toNameEx, tla.enumSet(tla.int(1), tla.int(3))) + val eq13 = eql(boundCell.toNameEx ? "I", enumSet(int(1), int(3)) ? "I") + .typed(types, "b") val eqState13 = rewriter.rewriteUntilDone(nextState.setRex(eq13)) solverContext.assertGroundExpr(eqState13.ex) assertUnsatOrExplain(rewriter, eqState13) // should not be possible rewriter.pop() // not equal to {2, 3} rewriter.push() - val eq23 = tla.eql(boundCell.toNameEx, tla.enumSet(tla.int(2), tla.int(3))) + val eq23 = eql(boundCell.toNameEx ? "I", enumSet(int(2), int(3)) ? "I") + .typed(types, "b") val eqState23 = rewriter.rewriteUntilDone(nextState.setRex(eq23)) solverContext.assertGroundExpr(eqState23.ex) assertUnsatOrExplain(rewriter, eqState23) // should not be possible rewriter.pop() // 2 is in the result rewriter.push() - val in23 = tla.in(tla.int(2), boundCell.toNameEx) + val in23 = in(int(2), boundCell.toNameEx ? "I") + .typed(types, "b") val inState23 = rewriter.rewriteUntilDone(nextState.setRex(in23)) solverContext.assertGroundExpr(inState23.ex) assert(solverContext.sat()) // should be possible rewriter.pop() } - test("""SE-IN-ASSIGN1(set): \E t \in SUBSET {1, 2}: x' \in {t} ~~> TRUE and [x -> $C$k]""") { - val set = tla.powSet(set12) - val assign = - OperEx(BmcOper.skolem, tla.exists(boundName, set, tla.assign(x_prime, boundName))) + test("""\E t \in SUBSET {1, 2}: x' \in {t} ~~> TRUE and [x -> $C$k]""") { + val set = powSet(set12).typed(types, "II") + val asgn = + apalacheSkolem(exists(boundName, set, assign(x_prime, boundName) ? "b") ? "b") + .typed(types, "b") - val state = new SymbState(assign, arena, Binding()) + val state = new SymbState(asgn, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) val boundCell = @@ -258,49 +266,55 @@ class TestSymbStateRewriterAssignment extends RewriterBase with TestingPredefs { assert(solverContext.sat()) // may equal to {1, 2} rewriter.push() - val eq12 = tla.eql(boundCell.toNameEx, set12) + val eq12 = eql(boundCell.toNameEx ? "I", set12) + .typed(types, "b") val eqState12 = rewriter.rewriteUntilDone(nextState.setRex(eq12)) solverContext.assertGroundExpr(eqState12.ex) assert(solverContext.sat()) // ok rewriter.pop() // may equal to {1} rewriter.push() - val eq1 = tla.eql(boundCell.toNameEx, tla.enumSet(tla.int(1))) + val eq1 = eql(boundCell.toNameEx ? "I", enumSet(int(1)) ? "I") + .typed(types, "b") val eqState1 = rewriter.rewriteUntilDone(nextState.setRex(eq1)) solverContext.assertGroundExpr(eqState1.ex) assert(solverContext.sat()) // ok rewriter.pop() // may equal to {2} rewriter.push() - val eq2 = tla.eql(boundCell.toNameEx, tla.enumSet(tla.int(2))) + val eq2 = eql(boundCell.toNameEx ? "I", enumSet(int(2)) ? "I") + .typed(types, "b") val eqState2 = rewriter.rewriteUntilDone(nextState.setRex(eq2)) solverContext.assertGroundExpr(eqState2.ex) assert(solverContext.sat()) // ok rewriter.pop() // may equal to {}, but this needs a type annotation rewriter.push() - val eqEmpty = tla.eql(boundCell.toNameEx, tla.withType(tla.enumSet(), AnnotationParser.toTla(FinSetT(IntT())))) + val eqEmpty = eql(boundCell.toNameEx ? "I", enumSet() ? "I") + .typed(types, "b") val eqStateEmpty = rewriter.rewriteUntilDone(nextState.setRex(eqEmpty)) solverContext.assertGroundExpr(eqStateEmpty.ex) assert(solverContext.sat()) // ok rewriter.pop() // not equal to {1, 2, 3} rewriter.push() - val eq13 = tla.eql(boundCell.toNameEx, tla.enumSet(tla.int(1), tla.int(2), tla.int(3))) + val eq13 = eql(boundCell.toNameEx ? "I", enumSet(int(1), int(2), int(3)) ? "I") + .typed(types, "b") val eqState13 = rewriter.rewriteUntilDone(nextState.setRex(eq13)) solverContext.assertGroundExpr(eqState13.ex) assertUnsatOrExplain(rewriter, eqState13) // should not be possible } - test("""SE-IN-ASSIGN1(fun): \E t \in {[x \in BOOLEAN |-> 0], [x2 \in BOOLEAN |-> 1]}: x' \in {t} ~~> TRUE""") { - val fun0 = tla.funDef(tla.int(0), tla.name("x2"), tla.booleanSet()) - val fun1 = tla.funDef(tla.int(1), tla.name("x3"), tla.booleanSet()) - val fun2 = tla.funDef(tla.int(2), tla.name("x4"), tla.booleanSet()) - val set = tla.enumSet(fun0, fun1) - val assign = - OperEx(BmcOper.skolem, tla.exists(boundName, set, tla.assign(x_prime, boundName))) + test("""\E t \in {[x \in BOOLEAN |-> 0], [x2 \in BOOLEAN |-> 1]}: x' \in {t} ~~> TRUE""") { + val fun0 = funDef(int(0), name("x2") ? "b", booleanSet() ? "B").typed(types, "b_to_i") + val fun1 = funDef(int(1), name("x3") ? "b", booleanSet() ? "B").typed(types, "b_to_i") + val fun2 = funDef(int(2), name("x4") ? "b", booleanSet() ? "B").typed(types, "b_to_i") + val set = enumSet(fun0, fun1).typed(types, "b_TO_i") + val asgn = + apalacheSkolem(exists(boundName, set, assign(x_prime, boundName) ? "b") ? "b") + .typed(types, "b") - val state = new SymbState(assign, arena, Binding()) + val state = new SymbState(asgn, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) // no contradiction introduced @@ -312,35 +326,35 @@ class TestSymbStateRewriterAssignment extends RewriterBase with TestingPredefs { // may equal to fun0 rewriter.push() - val eqFun0 = tla.eql(boundCell.toNameEx, fun0) + val eqFun0 = eql(boundCell.toNameEx ? "b_to_i", fun0).typed(types, "b") val eqStateFun0 = rewriter.rewriteUntilDone(nextState.setRex(eqFun0)) solverContext.assertGroundExpr(eqStateFun0.ex) assert(solverContext.sat()) // ok rewriter.pop() // may equal to fun1 rewriter.push() - val eqFun1 = tla.eql(boundCell.toNameEx, fun1) + val eqFun1 = eql(boundCell.toNameEx ? "b_to_i", fun1).typed(types, "b") val eqStateFun1 = rewriter.rewriteUntilDone(nextState.setRex(eqFun1)) solverContext.assertGroundExpr(eqStateFun1.ex) assert(solverContext.sat()) // also possible rewriter.pop() // not equal to fun2 rewriter.push() - val eqFun2 = tla.eql(boundCell.toNameEx, fun2) + val eqFun2 = eql(boundCell.toNameEx ? "b_to_i", fun2).typed(types, "b") val eqStateFun2 = rewriter.rewriteUntilDone(nextState.setRex(eqFun2)) solverContext.assertGroundExpr(eqStateFun2.ex) assertUnsatOrExplain(rewriter, eqStateFun2) // should not be possible } - test("""SE-IN-ASSIGN1(funset): \E t \in [BOOLEAN -> {0, 1}]: x' \in {t} ~~> TRUE""") { - val fun0 = tla.funDef(tla.int(0), tla.name("x"), tla.booleanSet()) - val fun1 = tla.funDef(tla.int(1), tla.name("x"), tla.booleanSet()) - val fun2 = tla.funDef(tla.int(2), tla.name("x"), tla.booleanSet()) - val set = tla.funSet(tla.booleanSet(), tla.enumSet(tla.int(0), tla.int(1))) - val assign = - OperEx(BmcOper.skolem, tla.exists(boundName, set, tla.assign(x_prime, boundName))) + test("""\E t \in [BOOLEAN -> {0, 1}]: x' \in {t} ~~> TRUE""") { + val fun0 = funDef(int(0), name("x") ? "b", booleanSet() ? "B").typed(types, "b_to_i") + val fun1 = funDef(int(1), name("x") ? "b", booleanSet() ? "B").typed(types, "b_to_i") + val fun2 = funDef(int(2), name("x") ? "b", booleanSet() ? "B").typed(types, "b_to_i") + val set = funSet(booleanSet() ? "I", enumSet(int(0), int(1)) ? "I").typed(types, "b_TO_i") + val asgn = + apalacheSkolem(exists(boundName, set, assign(x_prime, boundName) ? "b") ? "b").typed(types, "b") - val state = new SymbState(assign, arena, Binding()) + val state = new SymbState(asgn, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) val boundCell = @@ -359,33 +373,34 @@ class TestSymbStateRewriterAssignment extends RewriterBase with TestingPredefs { assert(solverContext.sat()) // may equal to fun0 rewriter.push() - val eqFun0 = tla.eql(boundCell.toNameEx, fun0) + val eqFun0 = eql(boundCell.toNameEx ? "b_to_i", fun0).typed(types, "b") val eqStateFun0 = rewriter.rewriteUntilDone(nextState.setRex(eqFun0)) solverContext.assertGroundExpr(eqStateFun0.ex) assert(solverContext.sat()) // ok rewriter.pop() // may equal to fun1 rewriter.push() - val eqFun1 = tla.eql(boundCell.toNameEx, fun1) + val eqFun1 = eql(boundCell.toNameEx ? "b_to_i", fun1).typed(types, "b") val eqStateFun1 = rewriter.rewriteUntilDone(nextState.setRex(eqFun1)) solverContext.assertGroundExpr(eqStateFun1.ex) assert(solverContext.sat()) // also possible rewriter.pop() // not equal to fun2 rewriter.push() - val eqFun2 = tla.eql(boundCell.toNameEx, fun2) + val eqFun2 = eql(boundCell.toNameEx ? "b_to_i", fun2).typed(types, "b") val eqStateFun2 = rewriter.rewriteUntilDone(nextState.setRex(eqFun2)) solverContext.assertGroundExpr(eqStateFun2.ex) assertUnsatOrExplain(rewriter, eqStateFun2) // should not be possible } - test("""SE-IN-ASSIGN1(funset): \E t \in [{} -> {0, 1}]: x' \in {t} ~~> FALSE""") { + test("""\E t \in [{} -> {0, 1}]: x' \in {t} ~~> FALSE""") { // regression - val set = tla.funSet(tla.enumSet(), tla.enumSet(tla.int(0), tla.int(1))) - val assign = - OperEx(BmcOper.skolem, tla.exists(boundName, set, tla.assign(x_prime, boundName))) + val set = funSet(enumSet() ? "I", enumSet(int(0), int(1)) ? "I").typed(types, "i_TO_i") + val asgn = + apalacheSkolem(exists(boundName, set, assign(x_prime, boundName) ? "b") ? "b") + .typed(types, "b") - val state = new SymbState(assign, arena, Binding()) + val state = new SymbState(asgn, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) // no contradiction introduced @@ -395,49 +410,53 @@ class TestSymbStateRewriterAssignment extends RewriterBase with TestingPredefs { assertTlaExAndRestore(rewriter, nextState) } - test("""SE-IN-ASSIGN1(funset): \E t \in [0..(5 - 1) -> BOOLEAN]: x' \in {t} ~~> TRUE""") { - val domain = tla.dotdot(tla.int(0), tla.minus(tla.int(5), tla.int(1))) - val set = tla.funSet(domain, boolset) - val assign = - OperEx(BmcOper.skolem, tla.exists(boundName, set, tla.assign(x_prime, boundName))) + test("""\E t \in [0..(5 - 1) -> BOOLEAN]: x' \in {t} ~~> TRUE""") { + val domain = dotdot(int(0), minus(int(5), int(1)) ? "i").typed(types, "I") + val set = funSet(domain, boolset).typed(types, "i_to_b") + val asgn = + apalacheSkolem(exists(boundName, set, assign(x_prime, boundName) ? "b") ? "b") + .typed(types, "b") - val state = new SymbState(assign, arena, Binding()) + val state = new SymbState(asgn, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) - val boundCell = - nextState.ex match { - case NameEx(name) => - assert(arena.cellTrue().toString == name) - assert(nextState.binding.toMap.size == 1) - assert(nextState.binding.contains("x'")) - nextState.binding("x'") + nextState.ex match { + case NameEx(name) => + assert(arena.cellTrue().toString == name) + assert(nextState.binding.toMap.size == 1) + assert(nextState.binding.contains("x'")) + nextState.binding("x'") - case _ => - fail("Unexpected rewriting result") - } + case _ => + fail("Unexpected rewriting result") + } } - test("""ASSIGN[funset with Nat]: \E t \in [0..4 -> Nat]: x' <- t""") { - val domain = tla.dotdot(tla.int(0), tla.int(4)) - val set = tla.funSet(domain, ValEx(TlaNatSet)) - val assign = - OperEx(BmcOper.skolem, tla.exists(boundName, set, tla.assign(x_prime, boundName))) + test("""\E t \in [0..4 -> Nat]: x' <- t""") { + val domain = dotdot(int(0), int(4)).typed(types, "I") + val set = funSet(domain, natSet() ? "I").typed(types, "i_TO_i") + val asgn = + apalacheSkolem(exists(boundName, set, assign(x_prime, boundName) ? "b") ? "b") + .typed(types, "b") - val state = new SymbState(assign, arena, Binding()) + val state = new SymbState(asgn, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) assert(rewriter.solverContext.sat()) val x = nextState.binding("x'") - assertTlaExAndRestore(rewriter, nextState.setRex(tla.ge(tla.appFun(x.toNameEx, tla.int(1)), tla.int(0)))) + val pred = ge(appFun(x.toNameEx ? "i_to_i", int(1)) ? "i", int(0)) + .typed(types, "b") + assertTlaExAndRestore(rewriter, nextState.setRex(pred)) } - test("""ASSIGN[funset with Int]: \E t \in [0..4 -> Int]: x' <- t""") { - val domain = tla.dotdot(tla.int(0), tla.int(4)) - val set = tla.funSet(domain, ValEx(TlaIntSet)) - val assign = - OperEx(BmcOper.skolem, tla.exists(boundName, set, tla.assign(x_prime, boundName))) + test("""\E t \in [0..4 -> Int]: x' <- t""") { + val domain = dotdot(int(0), int(4)).typed(types, "I") + val set = funSet(domain, intSet() ? "I").typed(types, "i_TO_i") + val asgn = + apalacheSkolem(exists(boundName, set, assign(x_prime, boundName) ? "b") ? "b") + .typed(types, "b") - val state = new SymbState(assign, arena, Binding()) + val state = new SymbState(asgn, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) assert(rewriter.solverContext.sat()) @@ -445,16 +464,18 @@ class TestSymbStateRewriterAssignment extends RewriterBase with TestingPredefs { } // the model checker will never meet such an expression, as it will be optimized into several existentials by ExprOptimizer - test("""SE-IN-ASSIGN1(tuple): \E t \in {<<1, FALSE, {1, 3}>>, <<2, TRUE, {4}>>}: x' \in {t}""") { - val set1 = tla.enumSet(tla.int(1), tla.int(3)) - val tuple1 = tla.tuple(tla.int(1), tla.bool(false), set1) - val set2 = tla.enumSet(tla.int(4)) - val tuple2 = tla.tuple(tla.int(2), tla.bool(true), set2) - val set = tla.enumSet(tuple1, tuple2) - val assign = - OperEx(BmcOper.skolem, tla.exists(boundName, set, tla.assign(x_prime, boundName))) - - val state = new SymbState(assign, arena, Binding()) + test("""\E t \in {<<1, FALSE, {1, 3}>>, <<2, TRUE, {4}>>}: x' = t""") { + val set1 = enumSet(int(1), int(3)).typed(types, "I") + val tuple1 = tuple(int(1), bool(false), set1).typed(types, "ibI") + val set2 = enumSet(int(4)).typed(types, "I") + val tuple2 = tuple(int(2), bool(true), set2).typed(types, "ibI") + val set = enumSet(tuple1, tuple2).typed(SetT1(types("ibI"))) + val asgn = + apalacheSkolem(exists(name("t") ? "ibI", set, + assign(prime(name("x") ? "ibI") ? "ibI", name("t") ? "ibI") ? "b") ? "b") + .typed(types, "b") + + val state = new SymbState(asgn, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) nextState.ex match { @@ -463,76 +484,13 @@ class TestSymbStateRewriterAssignment extends RewriterBase with TestingPredefs { assert(TupleT(List(IntT(), BoolT(), FinSetT(IntT()))) == cell.cellType) val membershipTest = - tla.and(tla.in(tla.appFun(x_prime, tla.int(1)), set12), tla.in(tla.appFun(x_prime, tla.int(2)), boolset), - tla.in(tla.appFun(x_prime, tla.int(3)), tla.enumSet(set1, set2))) /// - - assertTlaExAndRestore(rewriter, nextState.setRex(membershipTest)) - - case _ => - fail("Unexpected rewriting result") - } - } - - // the model checker will never meet such an expression, as it will be optimized into several existentials by ExprOptimizer - test( - """SE-IN-ASSIGN1(record): \E t \in {{"a" -> 1, "b" -> FALSE}, {"a" -> 2, "b" -> TRUE, "c" -> {3, 4}}}: x' \in {t}""") { - val annotation = AnnotationParser.toTla(RecordT(SortedMap("a" -> IntT(), "b" -> BoolT(), "c" -> FinSetT(IntT())))) - // records in a set can have different sets of keys, although the field types should be compatible for each field - val record1 = tla.enumFun(tla.str("a"), tla.int(1), tla.str("b"), tla.bool(false)) - val set34 = tla.enumSet(tla.int(3), tla.int(4)) - val record2 = tla.enumFun(tla.str("a"), tla.int(2), tla.str("b"), tla.bool(true), tla.str("c"), set34) - val recordSet = tla.enumSet(tla.withType(record1, annotation), record2) - val assign = - OperEx(BmcOper.skolem, tla.exists(boundName, recordSet, tla.assign(x_prime, boundName))) - - val state = new SymbState(assign, arena, Binding()) - val rewriter = create() - rewriter.typeFinder.inferAndSave(assign) // trigger type inference manually - val nextState = rewriter.rewriteUntilDone(state) - nextState.ex match { - case NameEx(_) => - val cell = nextState.binding("x'") - // x' is assigned a record from recordSet - assert(cell.cellType.isInstanceOf[RecordT]) - assert( - cell.cellType.asInstanceOf[RecordT].fields - == TreeMap("a" -> IntT(), "b" -> BoolT(), "c" -> FinSetT(IntT()))) - - val a_of_x_prime = tla.appFun(x_prime, tla.str("a")) - val b_of_x_prime = tla.appFun(x_prime, tla.str("b")) - val c_of_x_prime = tla.appFun(x_prime, tla.str("c")) - val membershipTest = - tla.and(tla.in(a_of_x_prime, set12), tla.in(b_of_x_prime, boolset)) - // interestingly, we cannot expect that x'.c \in { 3, 4 }, - // as the value of the field c is unknown for the first record - // tla.in(c_of_x_prime, tla.enumSet(set34)) + and(in(appFun(prime(name("x") ? "ibI") ? "ibI", int(1)) ? "i", set12) ? "b", + in(appFun(prime(name("x") ? "ibI") ? "ibI", int(2)) ? "i", boolset) ? "b", + in(appFun(prime(name("x") ? "ibI") ? "ibI", int(3)) ? "i", enumSet(set1, set2) ? "II") ? "b") + .typed(types, "b") assertTlaExAndRestore(rewriter, nextState.setRex(membershipTest)) - // if we assume that result["a"] = 2, we should get result["b"] = TRUE, and result["c"] = { 3, 4 } - rewriter.push() - val assumption = tla.eql(a_of_x_prime, tla.int(2)) - val assumState = assumeTlaEx(rewriter, nextState.setRex(assumption)) - - val bIsTrue = tla.eql(b_of_x_prime, tla.bool(true)) - assertTlaExAndRestore(rewriter, assumState.setRex(bIsTrue)) - val cIsSet34 = tla.eql(c_of_x_prime, set34) - assertTlaExAndRestore(rewriter, assumState.setRex(cIsSet34)) - rewriter.pop() - - // if we assume that result["a"] = 1, we should get DOMAIN result = {"a", "b"} - rewriter.push() - val assumption2 = tla.eql(a_of_x_prime, tla.int(1)) - val assumeState2 = assumeTlaEx(rewriter, nextState.setRex(assumption2)) - val (newArena, expectedDom) = - rewriter.recordDomainCache.getOrCreate(assumeState2.arena, (SortedSet("a", "b"), SortedSet("c"))) - val domEq = tla.eql(expectedDom.toNameEx, tla.dom(x_prime)) - assertTlaExAndRestore(rewriter, assumeState2.setArena(newArena).setRex(domEq)) - // and check that the record equals to the expected one - val eq = tla.eql(tla.withType(record1, annotation), x_prime) - assertTlaExAndRestore(rewriter, assumeState2.setRex(eq)) - rewriter.pop() - case _ => fail("Unexpected rewriting result") } diff --git a/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestSymbStateRewriterBool.scala b/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestSymbStateRewriterBool.scala index b27632f739..4bd924efbf 100644 --- a/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestSymbStateRewriterBool.scala +++ b/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestSymbStateRewriterBool.scala @@ -1,13 +1,10 @@ package at.forsyte.apalache.tla.bmcmt import at.forsyte.apalache.tla.bmcmt.SymbStateRewriter.NoRule -import at.forsyte.apalache.tla.bmcmt.types.eager.TrivialTypeFinder -import at.forsyte.apalache.tla.bmcmt.types.{AnnotationParser, BoolT, FinSetT, IntT} +import at.forsyte.apalache.tla.bmcmt.types.{BoolT, FinSetT, IntT} +import at.forsyte.apalache.tla.lir.TypedPredefs._ import at.forsyte.apalache.tla.lir._ import at.forsyte.apalache.tla.lir.convenience.tla -import at.forsyte.apalache.tla.lir.oper.{BmcOper, TlaBoolOper, TlaOper} -import at.forsyte.apalache.tla.lir.values.{TlaBool, TlaBoolSet, TlaNatSet} -import at.forsyte.apalache.tla.lir.UntypedPredefs._ import org.junit.runner.RunWith import org.scalatest.BeforeAndAfterEach import org.scalatest.junit.JUnitRunner @@ -19,6 +16,7 @@ class TestSymbStateRewriterBool extends RewriterBase with TestingPredefs with Be private var y: ArenaCell = new ArenaCell(100001, IntT()) private var set: ArenaCell = new ArenaCell(100001, FinSetT(IntT())) private var xyBinding = Binding() + private val boolTypes = Map("b" -> BoolT1(), "i" -> IntT1(), "I" -> SetT1(IntT1())) override def beforeEach() { super.beforeEach() @@ -33,11 +31,11 @@ class TestSymbStateRewriterBool extends RewriterBase with TestingPredefs with Be } test("FALSE ~~> $C$0") { - val ex = ValEx(TlaBool(false)) + val ex = tla.bool(false).typed() val state = new SymbState(ex, arena, Binding()) create().rewriteOnce(state) match { case SymbStateRewriter.Continue(nextState) => - val expected = NameEx("$C$0") + val expected = NameEx("$C$0")(Untyped()) assert(expected == nextState.ex) assert(state.arena == nextState.arena) @@ -47,11 +45,11 @@ class TestSymbStateRewriterBool extends RewriterBase with TestingPredefs with Be } test("TRUE ~~> $C$1") { - val ex = ValEx(TlaBool(true)) + val ex = tla.bool(true).typed() val state = new SymbState(ex, arena, Binding()) create().rewriteOnce(state) match { case SymbStateRewriter.Continue(nextState) => - val expected = NameEx("$C$1") + val expected = NameEx("$C$1")(Untyped()) assert(expected == nextState.ex) assert(state.arena == nextState.arena) @@ -60,11 +58,12 @@ class TestSymbStateRewriterBool extends RewriterBase with TestingPredefs with Be } } - test("SE-SET-BOOLEAN: BOOLEAN ~~> c_BOOLEAN") { - val state = new SymbState(ValEx(TlaBoolSet), arena, Binding()) + test("BOOLEAN ~~> c_BOOLEAN") { + val boolset = tla.booleanSet().typed(SetT1(BoolT1())) + val state = new SymbState(boolset, arena, Binding()) create().rewriteOnce(state) match { case SymbStateRewriter.Continue(nextState) => - val expected = NameEx("$C$2") + val expected = NameEx("$C$2")(Untyped()) assert(expected == nextState.ex) assert(state.arena == nextState.arena) @@ -73,20 +72,24 @@ class TestSymbStateRewriterBool extends RewriterBase with TestingPredefs with Be } } - test("SE-BOOL-IMPL: x => y ~~> ~x \\/ y") { + test("x => y ~~> ~x \\/ y") { // outside of KerA+, should be handled by Keramelizer and Normalizer - val ex = tla.impl(tla.name("x"), tla.name("y")) + val ex = tla + .impl(tla.name("x") ? "b", tla.name("y") ? "b") + .typed(boolTypes, "b") val state = new SymbState(ex, arena, xyBinding) assert(NoRule() == create().rewriteOnce(state)) } - test("SE-BOOL-EQUIV: x <=> y") { + test("x <=> y") { // outside of KerA+, should be handled by Keramelizer and Normalizer arena = arena.appendCell(BoolT()) val left = arena.topCell arena = arena.appendCell(BoolT()) val right = arena.topCell - val ex = tla.equiv(left.toNameEx, right.toNameEx) + val ex = tla + .equiv(left.toNameEx ? "b", right.toNameEx ? "b") + .typed(boolTypes, "b") val state = new SymbState(ex, arena, xyBinding) assert(NoRule() == create().rewriteOnce(state)) } @@ -94,35 +97,42 @@ class TestSymbStateRewriterBool extends RewriterBase with TestingPredefs with Be test("""IF-THEN-ELSE with \E: IF \E i \in {}: x' \in {i} THEN x' ELSE 0""") { // this tricky test comes from Bakery, where an assignment is made in one branch of a conjunction val exists = - OperEx(BmcOper.skolem, - tla.exists(tla.name("i"), tla.withType(tla.enumSet(), AnnotationParser.toTla(FinSetT(IntT()))), - tla.in(tla.prime(tla.name("x")), tla.enumSet(tla.name("i"))))) - val ite = tla.ite(exists, tla.prime(tla.name("x")), tla.int(0)) + tla + .apalacheSkolem(tla.exists(tla.name("i") ? "i", tla.enumSet() ? "I", + tla.in(tla.prime(tla.name("x") ? "i") ? "i", tla.enumSet(tla.name("i") ? "i") ? "I") ? "b") ? "b") + .typed(boolTypes, "b") + val ite = tla + .ite(exists, tla.prime(tla.name("x") ? "i") ? "i", tla.int(0)) + .typed(boolTypes, "b") val state = new SymbState(ite, arena, Binding()) - val rewriter = new SymbStateRewriterImpl(solverContext, new TrivialTypeFinder()) + val rewriter = new SymbStateRewriterImpl(solverContext) var nextState = rewriter.rewriteUntilDone(state) assert(solverContext.sat()) - assertTlaExAndRestore(rewriter, nextState.setRex(tla.eql(tla.int(0), nextState.ex))) + val eq = tla.eql(tla.int(0), nextState.ex).typed(BoolT1()) + assertTlaExAndRestore(rewriter, nextState.setRex(eq)) } - test("""SE-BOOL-NEG9: ~c_i ~~> b_new""") { + test("""~c_i ~~> b_new""") { arena = arena.appendCell(BoolT()) val cell = arena.topCell - val ex = OperEx(TlaBoolOper.not, cell.toNameEx) + val ex = tla.not(cell.toNameEx ? "b").typed(boolTypes, "b") val state = new SymbState(ex, arena, Binding()) val rewriter = create() rewriter.rewriteOnce(state) match { case SymbStateRewriter.Continue(nextState) => nextState.ex match { case NameEx(name) => - solverContext.assertGroundExpr(OperEx(TlaOper.eq, cell.toNameEx, arena.cellFalse().toNameEx)) + val eq = tla + .eql(cell.toNameEx ? "b", arena.cellFalse().toNameEx ? "b") + .typed(boolTypes, "b") + solverContext.assertGroundExpr(eq) solverContext.assertGroundExpr(nextState.ex) rewriter.push() assert(solverContext.sat()) rewriter.pop() - solverContext.assertGroundExpr(OperEx(TlaBoolOper.not, nextState.ex)) + solverContext.assertGroundExpr(tla.not(nextState.ex).typed(BoolT1())) assert(!solverContext.sat()) case _ => @@ -134,43 +144,60 @@ class TestSymbStateRewriterBool extends RewriterBase with TestingPredefs with Be } } - test("""SE-BOOL-NEG9: ~x ~~> TRUE""") { - val ex = OperEx(TlaBoolOper.not, NameEx("x")) + test("""~x ~~> TRUE""") { + val ex = tla + .not(tla.name("x") ? "b") + .typed(boolTypes, "b") val binding = Binding("x" -> arena.cellFalse()) val state = new SymbState(ex, arena, binding) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) - assertTlaExAndRestore(rewriter, nextState.setRex(tla.eql(nextState.ex, tla.bool(true)))) + assertTlaExAndRestore(rewriter, nextState.setRex(tla.eql(nextState.ex, tla.bool(true)).typed(BoolT1()))) } test("""FALSE = TRUE ~~> FALSE""") { - val ex = tla.eql(arena.cellFalse().toNameEx, arena.cellTrue().toNameEx) + val ex = tla + .eql(arena.cellFalse().toNameEx ? "b", arena.cellTrue().toNameEx ? "b") + .typed(boolTypes, "b") val state = new SymbState(ex, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) - assertTlaExAndRestore(rewriter, nextState.setRex(tla.eql(nextState.ex, arena.cellFalse().toNameEx))) + val eq = tla + .eql(nextState.ex, arena.cellFalse().toNameEx ? "b") + .typed(boolTypes, "b") + assertTlaExAndRestore(rewriter, nextState.setRex(eq)) } test("""x = TRUE ~~> FALSE when x = FALSE""") { - val ex = tla.eql(tla.name("x"), tla.bool(true)) + val ex = tla + .eql(tla.name("x") ? "b", tla.bool(true) ? "b") + .typed(boolTypes, "b") val binding = Binding("x" -> arena.cellFalse()) val state = new SymbState(ex, arena, binding) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) - assertTlaExAndRestore(rewriter, nextState.setRex(tla.eql(nextState.ex, tla.bool(false)))) + val eq = tla.eql(nextState.ex, tla.bool(false)).typed(BoolT1()) + assertTlaExAndRestore(rewriter, nextState.setRex(eq)) } test("""~(x = TRUE) ~~> TRUE when x = FALSE""") { - val ex = tla.not(tla.eql(tla.name("x"), tla.bool(true))) + val ex = tla + .not(tla.eql(tla.name("x") ? "b", tla.bool(true)) ? "b") + .typed(boolTypes, "b") val binding = Binding("x" -> arena.cellFalse()) val state = new SymbState(ex, arena, binding) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) - assertTlaExAndRestore(rewriter, nextState.setRex(tla.eql(nextState.ex, tla.bool(true)))) + val eq = tla + .eql(nextState.ex, tla.bool(true)) + .typed(boolTypes, "b") + assertTlaExAndRestore(rewriter, nextState.setRex(eq)) } - test("""SE-AND1: FALSE /\ TRUE ~~> $B$0""") { - val ex = OperEx(TlaBoolOper.and, ValEx(TlaBool(false)), ValEx(TlaBool(true))) + test("""FALSE /\ TRUE ~~> $B$0""") { + val ex = tla + .and(tla.bool(false), tla.bool(true)) + .typed(BoolT1()) val state = new SymbState(ex, arena, Binding()) create().rewriteOnce(state) match { case SymbStateRewriter.Continue(nextState) => @@ -182,13 +209,15 @@ class TestSymbStateRewriterBool extends RewriterBase with TestingPredefs with Be } } - test("""SE-AND4: c_1 /\ c_2 ~~> b_new""") { + test("""c_1 /\ c_2 ~~> b_new""") { arena = arena.appendCell(BoolT()) val c1 = arena.topCell arena = arena.appendCell(BoolT()) val c2 = arena.topCell - val ex = OperEx(TlaBoolOper.and, c1.toNameEx, c2.toNameEx) + val ex = tla + .and(c1.toNameEx ? "b", c2.toNameEx ? "b") + .typed(boolTypes, "b") val state = new SymbState(ex, arena, Binding()) val rewriter = create() rewriter.rewriteOnce(state) match { @@ -198,12 +227,21 @@ class TestSymbStateRewriterBool extends RewriterBase with TestingPredefs with Be assert(solverContext.sat()) solverContext.assertGroundExpr(nextState.ex) rewriter.push() - solverContext.assertGroundExpr(OperEx(TlaOper.eq, c1.toNameEx, arena.cellFalse().toNameEx)) + val eq1 = tla + .eql(c1.toNameEx ? "b", arena.cellFalse().toNameEx ? "b") + .typed(boolTypes, "b") + solverContext.assertGroundExpr(eq1) assert(!solverContext.sat()) rewriter.pop() - solverContext.assertGroundExpr(OperEx(TlaOper.eq, c1.toNameEx, arena.cellTrue().toNameEx)) + val eq2 = tla + .eql(c1.toNameEx ? "b", arena.cellTrue().toNameEx ? "") + .typed(boolTypes, "b") + solverContext.assertGroundExpr(eq2) assert(solverContext.sat()) - solverContext.assertGroundExpr(OperEx(TlaOper.eq, c2.toNameEx, arena.cellTrue().toNameEx)) + val eq3 = tla + .eql(c2.toNameEx ? "b", arena.cellTrue().toNameEx ? "b") + .typed(boolTypes, "b") + solverContext.assertGroundExpr(eq3) assert(solverContext.sat()) case _ => @@ -215,8 +253,8 @@ class TestSymbStateRewriterBool extends RewriterBase with TestingPredefs with Be } } - test("""SE-OR4: empty \/ ~~> $B$0""") { - val ex = OperEx(TlaBoolOper.or) + test("""empty \/ ~~> $B$0""") { + val ex = tla.or().typed(BoolT1()) val state = new SymbState(ex, arena, Binding()) create().rewriteOnce(state) match { case SymbStateRewriter.Continue(nextState) => @@ -228,8 +266,8 @@ class TestSymbStateRewriterBool extends RewriterBase with TestingPredefs with Be } } - test("""SE-AND4: empty /\ ~~> $B$1""") { - val ex = OperEx(TlaBoolOper.and) + test("""empty /\ ~~> $B$1""") { + val ex = tla.and().typed(BoolT1()) val state = new SymbState(ex, arena, Binding()) create().rewriteOnce(state) match { case SymbStateRewriter.Continue(nextState) => @@ -241,8 +279,8 @@ class TestSymbStateRewriterBool extends RewriterBase with TestingPredefs with Be } } - test("""SE-OR1: FALSE \/ TRUE ~~> $B$1""") { - val ex = OperEx(TlaBoolOper.or, ValEx(TlaBool(false)), ValEx(TlaBool(true))) + test("""FALSE \/ TRUE ~~> $B$1""") { + val ex = tla.or(tla.bool(false), tla.bool(true)).typed(BoolT1()) val state = new SymbState(ex, arena, Binding()) create().rewriteOnce(state) match { case SymbStateRewriter.Continue(nextState) => @@ -254,24 +292,32 @@ class TestSymbStateRewriterBool extends RewriterBase with TestingPredefs with Be } } - test("""SE-OR5: c_1 \/ c_2 ~~> b_new""") { + test("""c_1 \/ c_2 ~~> b_new""") { arena = arena.appendCell(BoolT()) val left = arena.topCell arena = arena.appendCell(BoolT()) val right = arena.topCell - val ex = OperEx(TlaBoolOper.or, left.toNameEx, right.toNameEx) + val ex = tla + .or(left.toNameEx ? "b", right.toNameEx ? "b") + .typed(boolTypes, "b") val state = new SymbState(ex, arena, Binding()) val rewriter = create() rewriter.rewriteOnce(state) match { case SymbStateRewriter.Continue(nextState) => nextState.ex match { case NameEx(name) => - solverContext.assertGroundExpr(OperEx(TlaOper.eq, left.toNameEx, arena.cellFalse().toNameEx)) + val eq1 = tla + .eql(left.toNameEx ? "b", arena.cellFalse().toNameEx ? "b") + .typed(boolTypes, "b") + solverContext.assertGroundExpr(eq1) solverContext.assertGroundExpr(nextState.ex) rewriter.push() assert(solverContext.sat()) - solverContext.assertGroundExpr(OperEx(TlaOper.eq, right.toNameEx, arena.cellFalse().toNameEx)) + val eq2 = tla + .eql(right.toNameEx ? "b", arena.cellFalse().toNameEx ? "b") + .typed(boolTypes, "b") + solverContext.assertGroundExpr(eq2) assert(!solverContext.sat()) case _ => @@ -283,13 +329,15 @@ class TestSymbStateRewriterBool extends RewriterBase with TestingPredefs with Be } } - test("""SE-BOOL-NE1: ~($B$1 = $B$2) ~~> $B$3""") { + test("""~($B$1 = $B$2) ~~> $B$3""") { arena = arena.appendCell(BoolT()) val left = arena.topCell arena = arena.appendCell(BoolT()) val right = arena.topCell - val ex = tla.not(tla.eql(left.toNameEx, right.toNameEx)) + val ex = tla + .not(tla.eql(left.toNameEx ? "b", right.toNameEx ? "b") ? "b") + .typed(boolTypes, "b") val state = new SymbState(ex, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) @@ -299,23 +347,41 @@ class TestSymbStateRewriterBool extends RewriterBase with TestingPredefs with Be rewriter.push() // both false assert(solverContext.sat()) - solverContext.assertGroundExpr(OperEx(TlaOper.eq, left.toNameEx, arena.cellFalse().toNameEx)) + val eq1 = tla + .eql(left.toNameEx ? "b", arena.cellFalse().toNameEx ? "b") + .typed(boolTypes, "b") + solverContext.assertGroundExpr(eq1) assert(solverContext.sat()) rewriter.push() - solverContext.assertGroundExpr(OperEx(TlaOper.eq, right.toNameEx, arena.cellFalse().toNameEx)) + val eq2 = tla + .eql(right.toNameEx ? "b", arena.cellFalse().toNameEx ? "b") + .typed(boolTypes, "b") + solverContext.assertGroundExpr(eq2) assert(!solverContext.sat()) rewriter.pop() - solverContext.assertGroundExpr(OperEx(TlaOper.eq, right.toNameEx, arena.cellTrue().toNameEx)) + val eq3 = tla + .eql(right.toNameEx ? "b", arena.cellTrue().toNameEx ? "b") + .typed(boolTypes, "b") + solverContext.assertGroundExpr(eq3) assert(solverContext.sat()) rewriter.pop() // both true - solverContext.assertGroundExpr(OperEx(TlaOper.eq, left.toNameEx, arena.cellTrue().toNameEx)) + val eq4 = tla + .eql(left.toNameEx ? "b", arena.cellTrue().toNameEx ? "b") + .typed(boolTypes, "b") + solverContext.assertGroundExpr(eq4) assert(solverContext.sat()) rewriter.push() - solverContext.assertGroundExpr(OperEx(TlaOper.eq, right.toNameEx, arena.cellTrue().toNameEx)) + val eq5 = tla + .eql(right.toNameEx ? "b", arena.cellTrue().toNameEx ? "b") + .typed(boolTypes, "b") + solverContext.assertGroundExpr(eq5) assert(!solverContext.sat()) rewriter.pop() - solverContext.assertGroundExpr(OperEx(TlaOper.eq, right.toNameEx, arena.cellFalse().toNameEx)) + val eq6 = tla + .eql(right.toNameEx ? "b", arena.cellFalse().toNameEx ? "b") + .typed(boolTypes, "b") + solverContext.assertGroundExpr(eq6) assert(solverContext.sat()) case _ => @@ -323,16 +389,21 @@ class TestSymbStateRewriterBool extends RewriterBase with TestingPredefs with Be } } - test("""SE-EX2: \E x \in {}: TRUE ~~> FALSE""") { - val ex = tla.exists(tla.name("x"), tla.enumSet(), tla.bool(true)) + test("""\E x \in {}: TRUE ~~> FALSE""") { + val ex = tla + .exists(tla.name("x") ? "i", tla.enumSet() ? "I", tla.bool(true)) + .typed(boolTypes, "b") val state = new SymbState(ex, arena, Binding()) val nextState = create().rewriteUntilDone(state) assert(arena.cellFalse().toNameEx == nextState.ex) } - test("""SE-EX3: \E x \in {1, 2, 3}: x = 2 ~~> $B$k""") { + test("""\E x \in {1, 2, 3}: x = 2 ~~> $B$k""") { + val set123 = tla.enumSet(tla.int(1), tla.int(2), tla.int(3)).typed(SetT1(IntT1())) val ex = - tla.exists(tla.name("x"), tla.enumSet(tla.int(1), tla.int(2), tla.int(3)), tla.eql(tla.int(2), tla.name("x"))) + tla + .exists(tla.name("x") ? "i", set123, tla.eql(tla.int(2), tla.name("x") ? "i") ? "b") + .typed(boolTypes, "b") val state = new SymbState(ex, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) @@ -341,18 +412,23 @@ class TestSymbStateRewriterBool extends RewriterBase with TestingPredefs with Be solverContext.assertGroundExpr(nextState.ex) assert(solverContext.sat()) rewriter.pop() - solverContext.assertGroundExpr(tla.not(nextState.ex)) + solverContext.assertGroundExpr(tla.not(nextState.ex).typed(BoolT1())) assertUnsatOrExplain(rewriter, nextState) } /** Jure, 9.12.19: Why should this throw? */ - test("""SE-EX: \E x \in {1, 2}: y' = x ~~> 2 assignments, regression""") { + test("""\E x \in {1, 2}: y' := x ~~> 2 assignments, regression""") { + val set12 = tla + .enumSet(tla.int(1), tla.int(2)) + .typed(SetT1(IntT1())) // an assignment inside an existential quantifier is tricky, as we can multiple values to variables - val ex = tla.exists( - tla.name("x"), - tla.enumSet(tla.int(1), tla.int(2)), - tla.assignPrime(tla.name("y"), tla.name("x")) - ) + val ex = tla + .exists( + tla.name("x") ? "i", + set12, + tla.assign(tla.prime(tla.name("y") ? "i") ? "i", tla.name("x") ? "i") ? "b" + ) + .typed(boolTypes, "b") //// val state = new SymbState(ex, arena, Binding()) val rewriter = create() @@ -365,25 +441,34 @@ class TestSymbStateRewriterBool extends RewriterBase with TestingPredefs with Be assert(solverContext.sat()) rewriter.pop() rewriter.push() - solverContext.assertGroundExpr(tla.not(nextState.ex)) + solverContext.assertGroundExpr(tla.not(nextState.ex).typed(BoolT1())) assertUnsatOrExplain(rewriter, nextState) rewriter.pop() rewriter.push() solverContext.assertGroundExpr(nextState.ex) - solverContext.assertGroundExpr(tla.eql(tla.int(1), nextState.binding("y'").toNameEx)) + val eq1 = tla + .eql(tla.int(1), nextState.binding("y'").toNameEx ? "i") + .typed(boolTypes, "b") + solverContext.assertGroundExpr(eq1) assert(solverContext.sat()) rewriter.pop() rewriter.push() solverContext.assertGroundExpr(nextState.ex) - solverContext.assertGroundExpr(tla.eql(tla.int(2), nextState.binding("y'").toNameEx)) + val eq2 = tla + .eql(tla.int(2), nextState.binding("y'").toNameEx ? "i") + .typed(boolTypes, "b") + solverContext.assertGroundExpr(eq2) assert(solverContext.sat()) rewriter.pop() } } - test("""SE-EX3: \E x \in {1, 2, 3}: x > 4 ~~> $B$k""") { + test("""\E x \in {1, 2, 3}: x > 4 ~~> $B$k""") { + val set123 = tla.enumSet(tla.int(1), tla.int(2), tla.int(3)).typed(SetT1(IntT1())) val ex = - tla.exists(tla.name("x"), tla.enumSet(tla.int(1), tla.int(2), tla.int(3)), tla.gt(tla.name("x"), tla.int(4))) + tla + .exists(tla.name("x") ? "i", set123, tla.gt(tla.name("x") ? "i", tla.int(4)) ? "b") + .typed(boolTypes, "b") val state = new SymbState(ex, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) @@ -392,35 +477,42 @@ class TestSymbStateRewriterBool extends RewriterBase with TestingPredefs with Be solverContext.assertGroundExpr(nextState.ex) assertUnsatOrExplain(rewriter, nextState) rewriter.pop() - solverContext.assertGroundExpr(tla.not(nextState.ex)) + solverContext.assertGroundExpr(tla.not(nextState.ex).typed(BoolT1())) assert(solverContext.sat()) } - test("""SE-EX: \E x \in {1} \ {1}: x > 4, regression""") { + test("""\E x \in {t \in {1}: FALSE}: x > 4, regression""") { def dynEmpty(left: TlaEx): TlaEx = { - tla.filter(tla.name("t"), left, tla.bool(false)) + tla + .filter(tla.name("t") ? "i", left ? "I", tla.bool(false)) + .typed(boolTypes, "I") } + val emptySet = dynEmpty(tla.enumSet(tla.int(1)).typed(SetT1(IntT1()))) + val pred = tla.gt(tla.name("x") ? "i", tla.int(4)).typed(boolTypes, "b") val ex = - OperEx(BmcOper.skolem, - tla.exists(tla.name("x"), dynEmpty(tla.enumSet(tla.int(1))), tla.gt(tla.name("x"), tla.int(4)))) + tla + .apalacheSkolem(tla.exists(tla.name("x") ? "i", emptySet, pred) ? "b") + .typed(boolTypes, "b") val state = new SymbState(ex, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) assert(solverContext.sat()) // regression test, the buggy implementation failed here - assertTlaExAndRestore(rewriter, nextState.setRex(tla.not(nextState.ex))) // E x \in {} is false + // E x \in {} is false + assertTlaExAndRestore(rewriter, nextState.setRex(tla.not(nextState.ex).typed(BoolT1()))) } - test("""SE-EX skolem: \E i \in Nat: i = 10 /\ x' \in {i}""") { + test("""skolem: \E i \in Nat: i = 10 /\ x' \in {i}""") { // this works for skolem constants only val ex = - OperEx(BmcOper.skolem, - tla.exists(tla.name("i"), ValEx(TlaNatSet), - tla.and( - tla.eql(tla.name("i"), tla.int(10)), - tla.assignPrime(tla.name("x"), tla.name("i")) - ))) + tla + .apalacheSkolem(tla.exists(tla.name("i") ? "i", tla.natSet() ? "I", + tla.and( + tla.eql(tla.name("i") ? "i", tla.int(10)) ? "b", + tla.assign(tla.prime(tla.name("x") ? "i") ? "i", tla.name("i") ? "i") ? "b" + ) ? "b") ? "b") + .typed(boolTypes, "b") val state = new SymbState(ex, arena, Binding()) val rewriter = create() @@ -428,29 +520,34 @@ class TestSymbStateRewriterBool extends RewriterBase with TestingPredefs with Be assert(solverContext.sat()) solverContext.assertGroundExpr(nextState.ex) val xp = nextState.binding("x'") - assertTlaExAndRestore(rewriter, nextState.setRex(tla.eql(xp.toNameEx, tla.int(10)))) + val eql = tla + .eql(xp.toNameEx ? "i", tla.int(10)) + .typed(boolTypes, "b") + assertTlaExAndRestore(rewriter, nextState.setRex(eql)) } - test("""SE-EX skolemization over range: \E i \in a..b: i % 3 = 1 /\ x' \in {i}""") { + test("""skolemization over range: \E i \in a..b: i % 3 = 1 /\ x' \in {i}""") { // this works for skolem constants only val ex = - OperEx(BmcOper.skolem, - tla.exists( - tla.name("i"), - tla.dotdot(tla.name("a"), tla.name("b")), - tla.and( - tla.eql(tla.mod(tla.name("i"), tla.int(3)), tla.int(1)), - tla.assignPrime(tla.name("x"), tla.name("i")) - ) - )) /// + tla + .apalacheSkolem( + tla.exists( + tla.name("i") ? "i", + tla.dotdot(tla.name("a") ? "i", tla.name("b") ? "i") ? "I", + tla.and( + tla.eql(tla.mod(tla.name("i") ? "i", tla.int(3)) ? "i", tla.int(1)) ? "b", + tla.assign(tla.prime(tla.name("x") ? "i") ? "i", tla.name("i") ? "i") ? "b" + ) ? "b" + ) ? "b") + .typed(boolTypes, "b") val rewriter = create() // rewrite 5 and 9 first, to produce a and b - var state = new SymbState(tla.int(5), arena, Binding()) + var state = new SymbState(tla.int(5).typed(), arena, Binding()) state = rewriter.rewriteUntilDone(state) val aCell = state.asCell - state = rewriter.rewriteUntilDone(state.setRex(tla.int(9))) + state = rewriter.rewriteUntilDone(state.setRex(tla.int(9).typed())) val bCell = state.asCell val binding: Binding = Binding("a" -> aCell, "b" -> bCell) @@ -458,12 +555,18 @@ class TestSymbStateRewriterBool extends RewriterBase with TestingPredefs with Be assert(solverContext.sat()) solverContext.assertGroundExpr(nextState.ex) val xp = nextState.binding("x'") - assertTlaExAndRestore(rewriter, nextState.setRex(tla.eql(xp.toNameEx, tla.int(7)))) + val eq = tla + .eql(xp.toNameEx ? "i", tla.int(7)) + .typed(boolTypes, "b") + assertTlaExAndRestore(rewriter, nextState.setRex(eq)) } - test("""SE-ALL3: \A x \in {1, 2, 3}: x < 10 ~~> $B$k""") { + test("""\A x \in {1, 2, 3}: x < 10 ~~> $B$k""") { + val set123 = tla.enumSet(tla.int(1), tla.int(2), tla.int(3)).typed(SetT1(IntT1())) val ex = - tla.forall(tla.name("x"), tla.enumSet(tla.int(1), tla.int(2), tla.int(3)), tla.lt(tla.name("x"), tla.int(10))) + tla + .forall(tla.name("x") ? "i", set123, tla.lt(tla.name("x") ? "i", tla.int(10)) ? "b") + .typed(boolTypes, "b") val state = new SymbState(ex, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) @@ -472,13 +575,16 @@ class TestSymbStateRewriterBool extends RewriterBase with TestingPredefs with Be solverContext.assertGroundExpr(nextState.ex) assert(solverContext.sat()) rewriter.pop() - solverContext.assertGroundExpr(tla.not(nextState.ex)) + solverContext.assertGroundExpr(tla.not(nextState.ex).typed(BoolT1())) assertUnsatOrExplain(rewriter, nextState) } - test("""SE-ALL3: \A x \in {1, 2, 3}: x > 2 ~~> $B$k""") { + test("""\A x \in {1, 2, 3}: x > 2 ~~> $B$k""") { + val set123 = tla.enumSet(tla.int(1), tla.int(2), tla.int(3)).typed(SetT1(IntT1())) val ex = - tla.forall(tla.name("x"), tla.enumSet(tla.int(1), tla.int(2), tla.int(3)), tla.gt(tla.name("x"), tla.int(2))) + tla + .forall(tla.name("x") ? "i", set123, tla.gt(tla.name("x") ? "i", tla.int(2)) ? "b") + .typed(boolTypes, "b") val state = new SymbState(ex, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) @@ -487,7 +593,7 @@ class TestSymbStateRewriterBool extends RewriterBase with TestingPredefs with Be solverContext.assertGroundExpr(nextState.ex) assertUnsatOrExplain(rewriter, nextState) rewriter.pop() - solverContext.assertGroundExpr(tla.not(nextState.ex)) + solverContext.assertGroundExpr(tla.not(nextState.ex).typed(BoolT1())) assert(solverContext.sat()) } } diff --git a/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestSymbStateRewriterChoose.scala b/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestSymbStateRewriterChoose.scala index b21dc49e8e..b47b6e3689 100644 --- a/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestSymbStateRewriterChoose.scala +++ b/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestSymbStateRewriterChoose.scala @@ -1,27 +1,35 @@ package at.forsyte.apalache.tla.bmcmt -import at.forsyte.apalache.tla.bmcmt.types.{AnnotationParser, FinSetT, IntT} -import at.forsyte.apalache.tla.lir.TestingPredefs -import at.forsyte.apalache.tla.lir.convenience.tla -import at.forsyte.apalache.tla.lir.UntypedPredefs._ +import at.forsyte.apalache.tla.lir.TypedPredefs._ +import at.forsyte.apalache.tla.lir.convenience.tla._ +import at.forsyte.apalache.tla.lir.{BoolT1, IntT1, SetT1, TestingPredefs} import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner @RunWith(classOf[JUnitRunner]) class TestSymbStateRewriterChoose extends RewriterBase with TestingPredefs { + private val types = Map( + "b" -> BoolT1(), + "i" -> IntT1(), + "I" -> SetT1(IntT1()) + ) + test("""CHOOSE x \in {1, 2, 3}: x > 1""") { val ex = - tla.choose(tla.name("x"), tla.enumSet(tla.int(1), tla.int(2), tla.int(3)), tla.gt(tla.name("x"), tla.int(1))) + choose(name("x") ? "i", enumSet(int(1), int(2), int(3)) ? "I", gt(name("x") ? "i", int(1)) ? "b") + .typed(types, "i") val state = new SymbState(ex, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) assert(solverContext.sat()) + def assertEq(i: Int): SymbState = { - val ns = rewriter.rewriteUntilDone(nextState.setRex(tla.eql(nextState.ex, tla.int(i)))) + val ns = rewriter.rewriteUntilDone(nextState.setRex(eql(nextState.ex ? "i", int(i)).typed(types, "b"))) solverContext.assertGroundExpr(ns.ex) ns } + // in our implementation, CHOOSE is non-deterministic, so all three results below are possible rewriter.push() assertEq(3) assert(solverContext.sat()) @@ -36,7 +44,8 @@ class TestSymbStateRewriterChoose extends RewriterBase with TestingPredefs { } test("""CHOOSE x \in {1}: x > 1""") { - val ex = tla.choose(tla.name("x"), tla.enumSet(tla.int(1)), tla.gt(tla.name("x"), tla.int(1))) + val ex = choose(name("x") ? "i", enumSet(int(1)) ? "I", gt(name("x"), int(1)) ? "b") + .typed(types, "i") val state = new SymbState(ex, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) @@ -48,15 +57,18 @@ class TestSymbStateRewriterChoose extends RewriterBase with TestingPredefs { } test("""CHOOSE x \in {}: x > 1""") { - val ex = tla.choose(tla.name("x"), tla.withType(tla.enumSet(), AnnotationParser.toTla(FinSetT(IntT()))), - tla.gt(tla.name("x"), tla.int(1))) + val ex = choose(name("x") ? "i", enumSet() ? "I", gt(name("x") ? "i", int(1)) ? "b") + .typed(types, "b") val state = new SymbState(ex, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) // the buggy implementation of choose fails on a dynamically empty set assert(solverContext.sat()) + def assertEq(i: Int): SymbState = { - val ns = rewriter.rewriteUntilDone(nextState.setRex(tla.eql(nextState.ex, tla.int(i)))) + val eq = eql(nextState.ex ? "i", int(i)) + .typed(types, "b") + val ns = rewriter.rewriteUntilDone(nextState.setRex(eq)) solverContext.assertGroundExpr(ns.ex) ns } diff --git a/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestSymbStateRewriterControl.scala b/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestSymbStateRewriterControl.scala index b71e4b42ae..de8a4e06b2 100644 --- a/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestSymbStateRewriterControl.scala +++ b/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestSymbStateRewriterControl.scala @@ -1,312 +1,126 @@ package at.forsyte.apalache.tla.bmcmt -import at.forsyte.apalache.tla.bmcmt.smt.{PreproSolverContext, Z3SolverContext} -import at.forsyte.apalache.tla.bmcmt.types.{FailPredT, IntT} +import at.forsyte.apalache.tla.lir.TypedPredefs._ import at.forsyte.apalache.tla.lir._ -import at.forsyte.apalache.tla.lir.convenience._ +import at.forsyte.apalache.tla.lir.convenience.tla._ import at.forsyte.apalache.tla.lir.oper.TlaSetOper -import at.forsyte.apalache.tla.lir.UntypedPredefs._ import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner @RunWith(classOf[JUnitRunner]) class TestSymbStateRewriterControl extends RewriterBase with TestingPredefs { - test("""SE-ITE[1-4]: IF 3 > 2 THEN 2 < 4 ELSE 5 < 1 ~~> $C$k""") { - val pred = tla.gt(tla.int(3), tla.int(2)) - val e1 = tla.lt(tla.int(2), tla.int(4)) - val e2 = tla.lt(tla.int(5), tla.int(1)) - val ite = tla.ite(pred, e1, e2) - - val state = new SymbState(ite, arena, Binding()) - val rewriter = create() - val nextState = rewriter.rewriteUntilDone(state) - nextState.ex match { - case res @ NameEx(name) => - rewriter.push() - solverContext.assertGroundExpr(res) - assert(solverContext.sat()) - rewriter.pop() - solverContext.assertGroundExpr(tla.not(res)) - assert(!solverContext.sat()) - - case _ => - fail("Unexpected rewriting result") - } + private val types = Map( + "b" -> BoolT1(), + "i" -> IntT1(), + "I" -> SetT1(IntT1()), + "O" -> OperT1(Seq(), IntT1()) + ) + + test("""IF 3 > 2 THEN 2 < 4 ELSE 5 < 1""") { + val pred = gt(int(3), int(2)) + val e1 = lt(int(2), int(4)) + val e2 = lt(int(5), int(1)) + val ifThenElse = ite(pred ? "b", e1 ? "b", e2 ? "b") + .typed(types, "b") + + val state = new SymbState(ifThenElse, arena, Binding()) + assertTlaExAndRestore(create(), state.setRex(ifThenElse)) } - test("""SE-ITE[1-4]: IF 3 < 2 THEN 2 < 4 ELSE 5 < 1 ~~> $C$k""") { - val pred = tla.lt(tla.int(3), tla.int(2)) - val e1 = tla.lt(tla.int(2), tla.int(4)) - val e2 = tla.lt(tla.int(5), tla.int(1)) - val ite = tla.ite(pred, e1, e2) + test("""IF 3 < 2 THEN 2 < 4 ELSE 5 < 1""") { + val pred = lt(int(3), int(2)) + val e1 = lt(int(2), int(4)) + val e2 = lt(int(5), int(1)) + val ifThenElse = not(ite(pred ? "b", e1 ? "b", e2 ? "b") ? "b") + .typed(types, "b") - val state = new SymbState(ite, arena, Binding()) + val state = new SymbState(ifThenElse, arena, Binding()) val rewriter = create() - val nextState = rewriter.rewriteUntilDone(state) - nextState.ex match { - case res @ NameEx(name) => - rewriter.push() - solverContext.assertGroundExpr(tla.not(res)) - assert(solverContext.sat()) - rewriter.pop() - solverContext.assertGroundExpr(res) - assert(!solverContext.sat()) - - case _ => - fail("Unexpected rewriting result") - } + assertTlaExAndRestore(create(), state.setRex(ifThenElse)) } - test("""SE-ITE[1-4]: IF 3 > 2 THEN 4 ELSE 1 ~~> $C$k""") { - val pred = tla.gt(tla.int(3), tla.int(2)) - val e1 = tla.int(4) - val e2 = tla.int(1) - val ite = tla.ite(pred, e1, e2) + test("""IF 3 > 2 THEN 4 ELSE 1""") { + val pred = gt(int(3), int(2)) + val e1 = int(4) + val e2 = int(1) + val ifThenElse = ite(pred ? "b", e1, e2) ? "i" + val eq4 = eql(ifThenElse, int(4)) + .typed(types, "b") - val state = new SymbState(ite, arena, Binding()) - val rewriter = create() - val nextState = rewriter.rewriteUntilDone(state) - nextState.ex match { - case res @ NameEx(name) => - rewriter.push() - solverContext.assertGroundExpr(tla.eql(tla.int(4), res)) - assert(solverContext.sat()) - rewriter.pop() - solverContext.assertGroundExpr(tla.eql(tla.int(1), res)) - assert(!solverContext.sat()) - - case _ => - fail("Unexpected rewriting result") - } + val state = new SymbState(eq4, arena, Binding()) + assertTlaExAndRestore(create(), state) } - test("""SE-ITE[1-4]: IF 3 < 2 THEN 4 ELSE 1 ~~> $C$k""") { - val pred = tla.lt(tla.int(3), tla.int(2)) - val e1 = tla.int(4) - val e2 = tla.int(1) - val ite = tla.ite(pred, e1, e2) - - val state = new SymbState(ite, arena, Binding()) - val rewriter = create() - val nextState = rewriter.rewriteUntilDone(state) - nextState.ex match { - case res @ NameEx(name) => - rewriter.push() - solverContext.assertGroundExpr(tla.eql(tla.int(1), res)) - assert(solverContext.sat()) - rewriter.pop() - solverContext.assertGroundExpr(tla.eql(tla.int(4), res)) - assert(!solverContext.sat()) + test("""IF 3 < 2 THEN 4 ELSE 1""") { + val pred = lt(int(3), int(2)) + val e1 = int(4) + val e2 = int(1) + val ifThenElse = ite(pred ? "b", e1, e2) ? "i" + val eq4 = eql(ifThenElse, int(1)) + .typed(types, "b") - case _ => - fail("Unexpected rewriting result") - } + val state = new SymbState(eq4, arena, Binding()) + assertTlaExAndRestore(create(), state) } - test("""SE-ITE[5]: IF 3 < 2 THEN {1, 2} ELSE {2, 3} ~~> {2, 3}""") { - val pred = tla.lt(tla.int(3), tla.int(2)) - val e1 = tla.enumSet(tla.int(1), tla.int(2)) - val e2 = tla.enumSet(tla.int(2), tla.int(3)) - val ite = tla.ite(pred, e1, e2) - - val state = new SymbState(ite, arena, Binding()) - val rewriter = create() - val nextState = rewriter.rewriteUntilDone(state) - nextState.ex match { - case res @ NameEx(name) => - assert(solverContext.sat()) - rewriter.push() - val eqState = rewriter.rewriteUntilDone(nextState.setRex(tla.eql(res, e2))) - solverContext.assertGroundExpr(eqState.ex) - assert(solverContext.sat()) - rewriter.pop() - val neqState = rewriter.rewriteUntilDone(nextState.setRex(tla.eql(res, e1))) - solverContext.assertGroundExpr(neqState.ex) - assert(!solverContext.sat()) - - case _ => - fail("Unexpected rewriting result") - } + test("""IF 3 < 2 THEN {1, 2} ELSE {2, 3} equals {2, 3}""") { + val pred = lt(int(3), int(2)) + val e1 = enumSet(int(1), int(2)) + val e2 = enumSet(int(2), int(3)) + val ifThenElse = ite(pred ? "b", e1 ? "I", e2 ? "I") + .typed(types, "I") + val eq23 = eql(ifThenElse, e2 ? "I") + .typed(types, "b") + + val state = new SymbState(eq23, arena, Binding()) + assertTlaExAndRestore(create(), state) } - test("""SE-ITE[5]: IF 1 = 1 THEN {2} ELSE {1} ] ~~> $C$k""") { - def mkSet(elems: TlaEx*) = OperEx(TlaSetOper.enumSet, elems: _*) - - val ite = tla.ite(tla.eql(tla.int(1), tla.int(1)), tla.enumSet(tla.int(2)), tla.enumSet(tla.int(1))) - val eq = tla.eql(tla.enumSet(tla.int(2)), ite) + test("""IF 1 = 1 THEN {2} ELSE {1} ]""") { + val ifThenElse = ite(eql(int(1), int(1)) ? "b", enumSet(int(2)) ? "I", enumSet(int(1)) ? "I") + val eq = eql(enumSet(int(2)) ? "I", ifThenElse ? "I") + .typed(types, "b") val state = new SymbState(eq, arena, Binding()) - val rewriter = create() - val nextState = rewriter.rewriteUntilDone(state) - nextState.ex match { - case membershipEx @ NameEx(name) => - assert(solverContext.sat()) - solverContext.push() - solverContext.assertGroundExpr(tla.not(nextState.ex)) - assertUnsatOrExplain(rewriter, nextState) - solverContext.pop() - solverContext.push() - solverContext.assertGroundExpr(nextState.ex) - assert(solverContext.sat()) - solverContext.pop() - - case _ => - fail("Unexpected rewriting result") - } + assertTlaExAndRestore(create(), state) } test("""SE-ITE[5]: IF 2 < 3 THEN {1, 2} ELSE {2, 3} ~~> {1, 2}""") { - val pred = tla.lt(tla.int(2), tla.int(3)) - val e1 = tla.enumSet(tla.int(1), tla.int(2)) - val e2 = tla.enumSet(tla.int(2), tla.int(3)) - val ite = tla.ite(pred, e1, e2) - - val state = new SymbState(ite, arena, Binding()) - val rewriter = create() - val nextState = rewriter.rewriteUntilDone(state) - nextState.ex match { - case res @ NameEx(name) => - assert(solverContext.sat()) - rewriter.push() - val eqState = rewriter.rewriteUntilDone(nextState.setRex(tla.eql(res, e1))) - solverContext.assertGroundExpr(eqState.ex) - assert(solverContext.sat()) - rewriter.pop() - val neqState = rewriter.rewriteUntilDone(nextState.setRex(tla.eql(res, e2))) - solverContext.assertGroundExpr(neqState.ex) - assert(!solverContext.sat()) + val pred = lt(int(2), int(3)) + val e1 = enumSet(int(1), int(2)) + val e2 = enumSet(int(2), int(3)) + val ifThenElse = ite(pred ? "b", e1 ? "I", e2 ? "I") + .typed(types, "I") + val eq = eql(ifThenElse, e1 ? "I") + .typed(types, "b") - case _ => - fail("Unexpected rewriting result") - } + val state = new SymbState(eq, arena, Binding()) + assertTlaExAndRestore(create(), state) } - test("""SE-ITE[1-4]: 1 + (IF 3 < 2 THEN 4 ELSE 1) ~~> $C$k""") { - val pred = tla.lt(tla.int(3), tla.int(2)) - val e1 = tla.int(4) - val e2 = tla.int(1) - val ite = tla.plus(tla.int(1), tla.ite(pred, e1, e2)) - - val state = new SymbState(ite, arena, Binding()) - val rewriter = create() - val nextState = rewriter.rewriteUntilDone(state) - nextState.ex match { - case res @ NameEx(name) => - rewriter.push() - solverContext.assertGroundExpr(tla.eql(tla.int(2), res)) - assert(solverContext.sat()) - rewriter.pop() - solverContext.assertGroundExpr(tla.eql(tla.int(5), res)) - assert(!solverContext.sat()) + test("""1 + (IF 3 < 2 THEN 4 ELSE 1)""") { + val pred = lt(int(3), int(2)) + val e1 = int(4) + val e2 = int(1) + val addition = plus(int(1), ite(pred ? "b", e1 ? "i", e2 ? "i") ? "i") + val eq = eql(addition ? "i", int(2)) + .typed(types, "b") - case _ => - fail("Unexpected rewriting result") - } + val state = new SymbState(eq, arena, Binding()) + assertTlaExAndRestore(create(), state) } // LET-IN is often used to cache computation results - test("""LET A == 1 + 2 IN 1 + A ~~> 4""") { - val decl = TlaOperDecl("A", List(), tla.plus(tla.int(1), tla.int(2))) - val letIn = tla.letIn(tla.plus(tla.int(1), tla.appDecl(decl)), decl) - val state = new SymbState(letIn, arena, Binding()) + test("""LET A == 1 + 2 IN 1 + A equals 4""") { + val decl = declOp("A", plus(int(1), int(2)) ? "i") + .typedOperDecl(types, "O") + val let = letIn(plus(int(1), appOp(name("A") ? "O") ? "i") ? "i", decl) + .typed(types, "i") + val state = new SymbState(let, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) - assertTlaExAndRestore(rewriter, nextState.setRex(tla.eql(nextState.ex, tla.int(4)))) - } - - // handled by Keramelizer - // CASE i = 1 -> 2 [] i = 2 -> 3 [] i = 3 -> 1] - ignore("""SE-CASE1: CASE i = 1 -> 2 [] i = 2 -> 3 [] i = 3 -> 1]""") { - def guard(arg: Int) = tla.eql(tla.name("i"), tla.int(arg)) - - val caseEx = tla.caseAny(guard(1), tla.int(2), guard(2), tla.int(3), guard(3), tla.int(1)) - - def caseExEqConst(i: Int) = tla.eql(tla.int(i), caseEx) - - for (i <- List(1, 2, 3)) { - // reinitialize the arena and the solver - solverContext = new PreproSolverContext(solverContext) - arena = Arena.create(solverContext) - arena = arena.appendCell(IntT()) - val icell = arena.topCell - val binding = Binding("i" -> icell) - val state = new SymbState(caseEx, arena, binding) - val rewriter = create() - val nextState = rewriter.rewriteUntilDone(state) - nextState.ex match { - case res @ NameEx(name) => - rewriter.push() - solverContext.assertGroundExpr(tla.eql(icell.toNameEx, tla.int(i))) - solverContext.assertGroundExpr(tla.eql(tla.int(1 + (i % 3)), res)) - assert(solverContext.sat()) - rewriter.pop() - rewriter.push() - val failureOccurs = tla.or(nextState.arena.findCellsByType(FailPredT()).map(_.toNameEx): _*) - solverContext.assertGroundExpr(failureOccurs) - assert( - solverContext.sat()) // this possible since there is no OTHER case and the constraints do not restrict us - rewriter.pop() - rewriter.push() - solverContext.assertGroundExpr(tla.eql(icell.toNameEx, tla.int(i))) - solverContext.assertGroundExpr(tla.not(failureOccurs)) - solverContext.assertGroundExpr(tla.eql(tla.int(i), res)) - assert(!solverContext.sat()) - rewriter.pop() - - case _ => - fail("Unexpected rewriting result") - } - - arena = nextState.arena // update the arena for the next iteration - } + val eq = eql(nextState.ex ? "i", int(4)) + .typed(types, "b") + assertTlaExAndRestore(rewriter, nextState.setRex(eq)) } - - // handled by Keramelizer - // CASE i = 1 -> 2 [] i = 2 -> 3 [] i = 3 -> 1 [] OTHER -> 4] - ignore("""SE-CASE1: CASE i = 1 -> 2 [] i = 2 -> 3 [] i = 3 -> 1 [] OTHER -> 4]""") { - def guard(arg: Int) = tla.eql(tla.name("i"), tla.int(arg)) - - val caseEx = tla.caseOther(tla.int(4), guard(1), tla.int(2), guard(2), tla.int(3), guard(3), tla.int(1)) - - def caseExEqConst(i: Int) = tla.eql(tla.int(i), caseEx) - - for (i <- List(1, 2, 3, 99)) { - // reinitialize the arena and the solver - solverContext = new PreproSolverContext(solverContext) - arena = Arena.create(solverContext) - arena = arena.appendCell(IntT()) - val icell = arena.topCell - val binding = Binding("i" -> icell) - val state = new SymbState(caseEx, arena, binding) - val rewriter = create() - val nextState = rewriter.rewriteUntilDone(state) - nextState.ex match { - case res @ NameEx(name) => - rewriter.push() - solverContext.assertGroundExpr(tla.eql(icell.toNameEx, tla.int(i))) - val expectedValue = if (i <= 3) 1 + (i % 3) else 4 - solverContext.assertGroundExpr(tla.eql(tla.int(expectedValue), res)) - assert(solverContext.sat()) - rewriter.pop() - rewriter.push() - val failureOccurs = tla.or(nextState.arena.findCellsByType(FailPredT()).map(_.toNameEx): _*) - solverContext.assertGroundExpr(failureOccurs) - assert(!solverContext.sat()) // no failure should occur, as there is the OTHER case - rewriter.pop() - rewriter.push() - solverContext.assertGroundExpr(tla.eql(icell.toNameEx, tla.int(i))) - solverContext.assertGroundExpr(tla.not(failureOccurs)) - solverContext.assertGroundExpr(tla.eql(tla.int(i), res)) - assert(!solverContext.sat()) - rewriter.pop() - - case _ => - fail("Unexpected rewriting result") - } - - arena = nextState.arena // update the arena for the next iteration - } - - } - } diff --git a/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestSymbStateRewriterExpand.scala b/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestSymbStateRewriterExpand.scala index 8cb721da9c..834f788a3d 100644 --- a/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestSymbStateRewriterExpand.scala +++ b/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestSymbStateRewriterExpand.scala @@ -1,55 +1,41 @@ package at.forsyte.apalache.tla.bmcmt -import at.forsyte.apalache.tla.bmcmt.types._ -import at.forsyte.apalache.tla.lir.{NameEx, OperEx, TlaEx} -import at.forsyte.apalache.tla.lir.convenience.tla -import at.forsyte.apalache.tla.lir.oper.BmcOper -import at.forsyte.apalache.tla.lir.UntypedPredefs._ +import at.forsyte.apalache.tla.lir.TypedPredefs._ +import at.forsyte.apalache.tla.lir.convenience.tla._ +import at.forsyte.apalache.tla.lir.{BoolT1, FunT1, IntT1, SetT1} import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner @RunWith(classOf[JUnitRunner]) class TestSymbStateRewriterExpand extends RewriterBase { + private val types = Map( + "b" -> BoolT1(), + "i" -> IntT1(), + "I" -> SetT1(IntT1()), + "II" -> SetT1(SetT1(IntT1())), + "B" -> SetT1(BoolT1()), + "i_TO_b" -> SetT1(FunT1(IntT1(), BoolT1())) + ) + test("""Expand(SUBSET {1, 2})""") { - val baseset = tla.enumSet(tla.int(1), tla.int(2)) - val expandPowset = OperEx(BmcOper.expand, tla.powSet(baseset)) - val state = new SymbState(expandPowset, arena, Binding()) - val rewriter = create() - var nextState = rewriter.rewriteUntilDone(state) - val powCell = nextState.asCell - // check equality - val eq = tla.eql(nextState.ex, - tla.enumSet(tla.withType(tla.enumSet(), AnnotationParser.toTla(FinSetT(IntT()))), tla.enumSet(tla.int(1)), - tla.enumSet(tla.int(2)), tla.enumSet(tla.int(1), tla.int(2)))) - assertTlaExAndRestore(rewriter, nextState.setRex(eq)) - } + val baseset = enumSet(int(1), int(2)) + val expandPowset = apalacheExpand(powSet(baseset ? "I") ? "II") + .typed(types, "II") + val subsets = enumSet(enumSet() ? "I", enumSet(int(1)) ? "I", enumSet(int(2)) ? "I", enumSet(int(1), int(2)) ? "I") + val eq = eql(expandPowset, subsets ? "II") + .typed(types, "b") - test("""Expand([{1, 2, 3} -> {FALSE, TRUE}]) should fail""") { - val domain = tla.enumSet(tla.int(1), tla.int(2), tla.int(3)) - val codomain = tla.enumSet(tla.bool(false), tla.bool(true)) - val funSet = OperEx(BmcOper.expand, tla.funSet(domain, codomain)) - val state = new SymbState(funSet, arena, Binding()) - val rewriter = create() - assertThrows[RewriterException](rewriter.rewriteUntilDone(state)) + val state = new SymbState(eq, arena, Binding()) + assertTlaExAndRestore(create(), state) } - // Constructing an explicit set of functions is, of course, expensive. But it should work for small values. - // Left for the future... - ignore("""Expand([{1, 2} -> {FALSE, TRUE}]) should work""") { - val domain = tla.enumSet(tla.int(1), tla.int(2)) - val codomain = tla.enumSet(tla.bool(false), tla.bool(true)) - val funSet = OperEx(BmcOper.expand, tla.funSet(domain, codomain)) - val state = new SymbState(funSet, arena, Binding()) + test("""Expand([{1, 2, 3} -> {FALSE, TRUE}]) fails as unsupported""") { + val domain = enumSet(int(1), int(2), int(3)) + val codomain = enumSet(bool(false), bool(true)) + val set = apalacheExpand(funSet(domain ? "I", codomain ? "B") ? "i_TO_b") + .typed(types, "i_TO_b") + val state = new SymbState(set, arena, Binding()) val rewriter = create() - var nextState = rewriter.rewriteUntilDone(state) - val funSetCell = nextState.asCell - def mkFun(v1: Boolean, v2: Boolean): TlaEx = { - val mapEx = tla.ite(tla.eql(NameEx("x"), tla.int(1)), tla.bool(v1), tla.bool(v2)) - tla.funDef(mapEx, tla.name("x"), domain) - } - - val expected = tla.enumSet(mkFun(false, false), mkFun(false, true), mkFun(true, false), mkFun(true, true)) - assertTlaExAndRestore(rewriter, nextState.setRex(tla.eql(expected, funSetCell.toNameEx))) + assertThrows[RewriterException](rewriter.rewriteUntilDone(state)) } - } diff --git a/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestSymbStateRewriterFiniteSets.scala b/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestSymbStateRewriterFiniteSets.scala index 35c18bf821..0f9711c70b 100644 --- a/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestSymbStateRewriterFiniteSets.scala +++ b/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestSymbStateRewriterFiniteSets.scala @@ -1,116 +1,141 @@ package at.forsyte.apalache.tla.bmcmt -import at.forsyte.apalache.tla.lir.{OperEx, TlaEx} -import at.forsyte.apalache.tla.lir.convenience.tla -import at.forsyte.apalache.tla.lir.oper.BmcOper -import at.forsyte.apalache.tla.lir.UntypedPredefs._ +import at.forsyte.apalache.tla.lir.TypedPredefs._ +import at.forsyte.apalache.tla.lir.convenience.tla._ +import at.forsyte.apalache.tla.lir.{BoolT1, IntT1, SetT1, TlaEx} import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner @RunWith(classOf[JUnitRunner]) class TestSymbStateRewriterFiniteSets extends RewriterBase { + private val types = Map( + "b" -> BoolT1(), + "i" -> IntT1(), + "I" -> SetT1(IntT1()) + ) + test("""Cardinality({1, 2, 3}) = 3""") { - val set = tla.enumSet(1.to(3).map(tla.int): _*) - val card = tla.card(set) - val state = new SymbState(card, arena, Binding()) + val set = enumSet(1.to(3).map(int): _*) + val cardinality = card(set ? "I") + .typed(types, "i") + val eq = eql(cardinality, int(3)) + .typed(types, "b") + + val state = new SymbState(eq, arena, Binding()) val rewriter = create() - val nextState = rewriter.rewriteUntilDone(state) - assert(solverContext.sat()) - assertTlaExAndRestore(rewriter, nextState.setRex(tla.eql(tla.int(3), nextState.ex))) + assertTlaExAndRestore(create(), state) } test("""Cardinality({1, 2, 2, 2, 3, 3}) = 3""") { - val set = tla.enumSet(Seq(1, 2, 2, 2, 3, 3).map(tla.int): _*) - val card = tla.card(set) - val state = new SymbState(card, arena, Binding()) + val set = enumSet(Seq(1, 2, 2, 2, 3, 3).map(int): _*) + val cardinality = card(set ? "I") + .typed(types, "i") + val eq = eql(cardinality, int(3)) + .typed(types, "b") + + val state = new SymbState(eq, arena, Binding()) val rewriter = create() - val nextState = rewriter.rewriteUntilDone(state) - assert(solverContext.sat()) - assertTlaExAndRestore(rewriter, nextState.setRex(tla.eql(tla.int(3), nextState.ex))) + assertTlaExAndRestore(create(), state) } - test("""BMC!ConstCard(Cardinality({1, 2, 3}) >= 3)""") { - val set = tla.enumSet(1.to(3).map(tla.int): _*) - val cardCmp = OperEx(BmcOper.constCard, tla.ge(tla.card(set), tla.int(3))) + test("""Apalache!ConstCard(Cardinality({1, 2, 3}) >= 3)""") { + val set = enumSet(1.to(3).map(int): _*) + val cardCmp = apalacheConstCard(ge(card(set ? "I") ? "i", int(3)) ? "b") + .typed(types, "b") val state = new SymbState(cardCmp, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) + // note that this optimization only works in the positive form. Its negation may be SAT. assert(solverContext.sat()) solverContext.assertGroundExpr(nextState.ex) assert(solverContext.sat()) } - test("""BMC!ConstCard(Cardinality({1, 2, 3}) >= 4)""") { - val set = tla.enumSet(1.to(3).map(tla.int): _*) - val cardCmp = OperEx(BmcOper.constCard, tla.ge(tla.card(set), tla.int(4))) + test("""Apalache!ConstCard(Cardinality({1, 2, 3}) >= 4)""") { + val set = enumSet(1.to(3).map(int): _*) + val cardCmp = apalacheConstCard(ge(card(set ? "I") ? "i", int(4)) ? "b") + .typed(types, "b") val state = new SymbState(cardCmp, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) + // note that this optimization only works in the positive form. Its negation may be SAT. assert(solverContext.sat()) solverContext.assertGroundExpr(nextState.ex) assert(!solverContext.sat()) } - test("""BMC!ConstCard(Cardinality({1, 2, 2, 3}) >= 4)""") { - val set = tla.enumSet(List(1, 2, 2, 3).map(tla.int): _*) - val cardCmp = OperEx(BmcOper.constCard, tla.ge(tla.card(set), tla.int(4))) + test("""Apalache!ConstCard(Cardinality({1, 2, 2, 3}) >= 4)""") { + val set = enumSet(Seq(1, 2, 2, 3).map(int): _*) + val cardCmp = apalacheConstCard(ge(card(set ? "I") ? "i", int(4)) ? "b") + .typed(types, "b") val state = new SymbState(cardCmp, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) + // note that this optimization only works in the positive form. Its negation may be SAT. assert(solverContext.sat()) solverContext.assertGroundExpr(nextState.ex) assert(!solverContext.sat()) } - test("""BMC!ConstCard(Cardinality({1, 2, 2, 3, 3}) >= 4)""") { - val set = tla.enumSet(List(1, 2, 2, 3).map(tla.int): _*) - val cardCmp = OperEx(BmcOper.constCard, tla.ge(tla.card(set), tla.int(4))) + test("""Apalache!ConstCard(Cardinality({1, 2, 2, 3, 3}) >= 4)""") { + val set = enumSet(Seq(1, 2, 2, 3, 3).map(int): _*) + val cardCmp = apalacheConstCard(ge(card(set ? "I") ? "i", int(4)) ? "b") + .typed(types, "b") val state = new SymbState(cardCmp, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) + // note that this optimization only works in the positive form. Its negation may be SAT. assert(solverContext.sat()) solverContext.assertGroundExpr(nextState.ex) assert(!solverContext.sat()) } - test("""BMC!ConstCard(Cardinality({}) >= 0)""") { - val set = tla.enumSet() - val cardCmp = OperEx(BmcOper.constCard, tla.ge(tla.card(set), tla.int(0))) + test("""Apalache!ConstCard(Cardinality({}) >= 0)""") { + val set = enumSet() + val cardCmp = apalacheConstCard(ge(card(set ? "I") ? "i", int(0)) ? "b") + .typed(types, "b") val state = new SymbState(cardCmp, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) + // note that this optimization only works in the positive form. Its negation may be SAT. assert(solverContext.sat()) solverContext.assertGroundExpr(nextState.ex) assert(solverContext.sat()) } - test("""BMC!ConstCard(Cardinality({x \in {}: FALSE}) >= 0)""") { - val set = tla.filter(tla.name("x"), tla.enumSet(), tla.bool(false)) - val cardCmp = OperEx(BmcOper.constCard, tla.ge(tla.card(set), tla.int(0))) + test("""Apalache!ConstCard(Cardinality({}) >= 1)""") { + val set = enumSet() + val cardCmp = apalacheConstCard(ge(card(set ? "I") ? "i", int(1)) ? "b") + .typed(types, "b") val state = new SymbState(cardCmp, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) + // note that this optimization only works in the positive form. Its negation may be SAT. assert(solverContext.sat()) solverContext.assertGroundExpr(nextState.ex) - assert(solverContext.sat()) + assert(!solverContext.sat()) } - test("""BMC!ConstCard(Cardinality({x \in {}: FALSE}) >= 1)""") { - val set = tla.filter(tla.name("x"), tla.enumSet(), tla.bool(false)) - val cardCmp = OperEx(BmcOper.constCard, tla.ge(tla.card(set), tla.int(1))) + test("""Apalache!ConstCard(Cardinality({x \in {}: FALSE}) >= 0)""") { + val set = filter(name("x") ? "i", enumSet() ? "I", bool(false)) + val cardCmp = apalacheConstCard(ge(card(set ? "I") ? "i", int(0)) ? "b") + .typed(types, "b") val state = new SymbState(cardCmp, arena, Binding()) val rewriter = create() + // note that this optimization only works in the positive form. Its negation may be SAT. val nextState = rewriter.rewriteUntilDone(state) assert(solverContext.sat()) solverContext.assertGroundExpr(nextState.ex) - assert(!solverContext.sat()) + assert(solverContext.sat()) } - test("""BMC!ConstCard(Cardinality({}) >= 1)""") { - val set = tla.enumSet() - val cardCmp = OperEx(BmcOper.constCard, tla.ge(tla.card(set), tla.int(1))) + test("""Apalache!ConstCard(Cardinality({x \in {}: FALSE}) >= 1)""") { + val set = filter(name("x") ? "i", enumSet() ? "I", bool(false)) + val cardCmp = apalacheConstCard(ge(card(set ? "I") ? "i", int(1)) ? "b") + .typed(types, "b") val state = new SymbState(cardCmp, arena, Binding()) val rewriter = create() + // note that this optimization only works in the positive form. Its negation may be SAT. val nextState = rewriter.rewriteUntilDone(state) assert(solverContext.sat()) solverContext.assertGroundExpr(nextState.ex) @@ -119,26 +144,24 @@ class TestSymbStateRewriterFiniteSets extends RewriterBase { test("""Cardinality({1, 2, 3} \ {2}) = 2""") { def setminus(set: TlaEx, intVal: Int): TlaEx = { - tla.filter(tla.name("t"), set, tla.not(tla.eql(tla.name("t"), tla.int(intVal)))) + filter(name("t") ? "i", set ? "I", not(eql(name("t") ? "i", int(intVal)) ? "b") ? "b") + .typed(types, "I") } - val set = setminus(tla.enumSet(1.to(3).map(tla.int): _*), 2) - val card = tla.card(set) - val state = new SymbState(card, arena, Binding()) - val rewriter = create() - val nextState = rewriter.rewriteUntilDone(state) - assert(solverContext.sat()) - assertTlaExAndRestore(rewriter, nextState.setRex(tla.eql(tla.int(2), nextState.ex))) + val set = setminus(enumSet(1.to(3).map(int): _*).typed(types, "I"), 2) + val cardinality = card(set ? "I") + val eq = eql(cardinality ? "i", int(2)) + .typed(types, "b") + + val state = new SymbState(eq, arena, Binding()) + assertTlaExAndRestore(create(), state) } test("""IsFiniteSet({1, 2, 3}) = TRUE""") { - val set = tla.enumSet(1.to(3).map(tla.int): _*) - val card = tla.isFin(set) - val state = new SymbState(card, arena, Binding()) - val rewriter = create() - val nextState = rewriter.rewriteUntilDone(state) - assert(solverContext.sat()) - assertTlaExAndRestore(rewriter, nextState.setRex(tla.eql(tla.bool(true), nextState.ex))) + val set = enumSet(1.to(3).map(int): _*) + val isFiniteSet = isFin(set ? "I") + .typed(types, "b") + val state = new SymbState(isFiniteSet, arena, Binding()) + assertTlaExAndRestore(create(), state) } - } diff --git a/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestSymbStateRewriterFun.scala b/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestSymbStateRewriterFun.scala index d292a7bedd..1e4ea6ba5d 100644 --- a/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestSymbStateRewriterFun.scala +++ b/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestSymbStateRewriterFun.scala @@ -1,30 +1,38 @@ package at.forsyte.apalache.tla.bmcmt import at.forsyte.apalache.tla.bmcmt.types._ -import at.forsyte.apalache.tla.bmcmt.types.eager.TrivialTypeFinder import at.forsyte.apalache.tla.lir._ -import at.forsyte.apalache.tla.lir.convenience.tla +import at.forsyte.apalache.tla.lir.convenience.tla._ import at.forsyte.apalache.tla.lir.oper._ import at.forsyte.apalache.tla.lir.values.TlaBoolSet import at.forsyte.apalache.tla.lir.values.TlaInt -import at.forsyte.apalache.tla.lir.UntypedPredefs._ +import at.forsyte.apalache.tla.lir.TypedPredefs._ import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner @RunWith(classOf[JUnitRunner]) class TestSymbStateRewriterFun extends RewriterBase with TestingPredefs { - test("""SE-FUN-CTOR[1-2]: [x \in {1,2,3,4} |-> x / 3: ] ~~> $C$k""") { - def mkSet(elems: TlaEx*) = OperEx(TlaSetOper.enumSet, elems: _*) - - val set = mkSet(tla.int(1), tla.int(2), tla.int(3), tla.int(4)) - val mapping = OperEx(TlaArithOper.div, NameEx("x"), tla.int(3)) - val fun = OperEx(TlaFunOper.funDef, mapping, NameEx("x"), set) + private val types = + Map("b" -> BoolT1(), "B" -> SetT1(BoolT1()), "i" -> IntT1(), "I" -> SetT1(IntT1()), "(i)" -> TupT1(IntT1()), + "i_to_i" -> FunT1(IntT1(), IntT1()), "i_to_I" -> FunT1(IntT1(), SetT1(IntT1())), "r" -> RecT1("a" -> IntT1()), + "s" -> StrT1(), "S" -> SetT1(StrT1()), "(s)" -> TupT1(StrT1()), "i_to_s" -> FunT1(StrT1(), IntT1()), + "s_to_i" -> FunT1(IntT1(), StrT1()), "i_to_r" -> FunT1(IntT1(), RecT1("a" -> IntT1())), + "b_to_b" -> FunT1(BoolT1(), BoolT1()), "b_TO_b" -> SetT1(FunT1(BoolT1(), BoolT1())), + "i_to_b_to_b" -> FunT1(IntT1(), FunT1(BoolT1(), BoolT1()))) + + test("""[x \in {1,2,3,4} |-> x / 3: ]""") { + val set = enumSet(1.to(4).map(int): _*) + .typed(types, "I") + val mapping = div(name("x") ? "i", int(3)) + .typed(types, "i") + val fun = funDef(mapping, name("x") ? "i", set) + .typed(types, "i_to_i") val state = new SymbState(fun, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) nextState.ex match { - case membershipEx @ NameEx(name) => + case NameEx(name) => assert(solverContext.sat()) val cell = nextState.arena.findCellByName(name) cell.cellType match { @@ -40,24 +48,25 @@ class TestSymbStateRewriterFun extends RewriterBase with TestingPredefs { } } - test("""SE-FUN-CTOR[1-2]: [x \in {1,2,3} |-> IF x = 1 THEN {2} ELSE IF x = 2 THEN {3} ELSE {1} ] ~~> $C$k""") { - val set = tla.enumSet(tla.int(1), tla.int(2), tla.int(3)) + test(""" [x \in {1,2,3} |-> IF x = 1 THEN {2} ELSE IF x = 2 THEN {3} ELSE {1} ]""") { + val set = enumSet(1.to(3).map(int): _*) + .typed(types, "I") - def intSet(i: Int) = tla.enumSet(tla.int(i)) + def intSet(i: Int) = enumSet(int(i)).typed(types, "I") - val mapping = tla.ite( - tla.eql(tla.name("x"), tla.int(1)), + val mapping = ite( + eql(name("x"), int(1)) ? "b", intSet(2), - tla.ite(tla.eql(tla.name("x"), tla.int(2)), intSet(3), intSet(1)) - ) - //// - val fun = tla.funDef(mapping, tla.name("x"), set) + ite(eql(name("x") ? "i", int(2)) ? "b", intSet(3), intSet(1)) ? "I" + ).typed(types, "I") + val fun = funDef(mapping, name("x") ? "i", set) + .typed(types, "i_to_I") val state = new SymbState(fun, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) nextState.ex match { - case membershipEx @ NameEx(name) => + case NameEx(_) => assert(solverContext.sat()) case _ => @@ -65,25 +74,27 @@ class TestSymbStateRewriterFun extends RewriterBase with TestingPredefs { } } - test("""SE-FUN-CTOR[1-2]: [x \in {1,2} |-> {} ][1] = {} ~~> $C$k""") { - def mkSet(elems: TlaEx*) = OperEx(TlaSetOper.enumSet, elems: _*) - - val set = mkSet(tla.int(1), tla.int(2)) - val mapping = tla.enumSet() - val fun = OperEx(TlaFunOper.funDef, mapping, NameEx("x"), set) - val eq = tla.eql(tla.appFun(fun, tla.int(1)), tla.enumSet()) + test("""[x \in {1,2} |-> {} ][1] = {}""") { + val set = enumSet(int(1), int(2)) + .typed(types, "I") + val mapping = enumSet() + .typed(types, "I") + val fun = funDef(mapping, name("x"), set) + .typed(types, "i_to_I") + val eq = eql(appFun(fun, int(1)) ? "I", enumSet() ? "I") + .typed(types, "b") val state = new SymbState(eq, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) nextState.ex match { - case result @ NameEx(name) => + case result @ NameEx(_) => solverContext.push() solverContext.assertGroundExpr(result) assert(solverContext.sat()) solverContext.pop() solverContext.push() - solverContext.assertGroundExpr(tla.not(result)) + solverContext.assertGroundExpr(not(result ? "b").typed(types, "b")) assert(!solverContext.sat()) case _ => @@ -91,29 +102,28 @@ class TestSymbStateRewriterFun extends RewriterBase with TestingPredefs { } } - test("""SE-FUN-CTOR[1-2]: [x \in {1,2} |-> IF x = 1 THEN {2} ELSE {1} ][1] ~~> $C$k""") { - def mkSet(elems: TlaEx*) = OperEx(TlaSetOper.enumSet, elems: _*) - - val set = tla.enumSet(tla.int(1), tla.int(2)) - val mapping = tla.ite(tla.eql(tla.name("x"), tla.int(1)), tla.enumSet(tla.int(2)), tla.enumSet(tla.int(1))) - val fun = tla.funDef(mapping, tla.name("x"), set) - val eq = tla.eql(tla.enumSet(tla.int(2)), tla.appFun(fun, tla.int(1))) + test("""[x \in {1,2} |-> IF x = 1 THEN {2} ELSE {1} ][1]""") { + val set = enumSet(int(1), int(2)) + .typed(types, "I") + val mapping = ite(eql(name("x") ? "i", int(1)) ? "b", enumSet(int(2)) ? "I", enumSet(int(1)) ? "I") + .typed(types, "I") + val fun = funDef(mapping, name("x") ? "i", set) + .typed(types, "i_to_I") + val eq = eql(enumSet(int(2)) ? "I", appFun(fun, int(1)) ? "I") + .typed(types, "b") val state = new SymbState(eq, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) nextState.ex match { - case membershipEx @ NameEx(name) => - assert(solverContext.sat()) - val failureOccurs = tla.or(nextState.arena.findCellsByType(FailPredT()).map(_.toNameEx): _*) - solverContext.assertGroundExpr(tla.not(failureOccurs)) + case NameEx(_) => assert(solverContext.sat()) solverContext.push() solverContext.assertGroundExpr(nextState.ex) assert(solverContext.sat()) solverContext.pop() solverContext.push() - solverContext.assertGroundExpr(tla.not(nextState.ex)) + solverContext.assertGroundExpr(not(nextState.ex ? "b").typed(types, "b")) assert(!solverContext.sat()) solverContext.pop() @@ -123,27 +133,28 @@ class TestSymbStateRewriterFun extends RewriterBase with TestingPredefs { } // regression: this test did not work with EWD840 - test("""SE-FUN-CTOR[1-2]: [x \in {1,2} |-> ["a" |-> x] ][1] ~~> $C$k""") { - val set = tla.enumSet(tla.int(1), tla.int(2)) - val mapping = tla.enumFun(tla.str("a"), tla.name("x")) - val fun = tla.funDef(mapping, tla.name("x"), set) - val eq = tla.eql(tla.enumFun(tla.str("a"), tla.int(1)), tla.appFun(fun, tla.int(1))) + test("""[x \in {1,2} |-> ["a" |-> x] ][1]""") { + val set = enumSet(int(1), int(2)) + .typed(types, "I") + val mapping = enumFun(str("a"), name("x") ? "i") + .typed(types, "r") + val fun = funDef(mapping, name("x") ? "i", set) + .typed(types, "i_to_r") + val eq = eql(enumFun(str("a"), int(1)) ? "r", appFun(fun, int(1)) ? "r") + .typed(types, "b") val state = new SymbState(eq, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) nextState.ex match { - case membershipEx @ NameEx(name) => - assert(solverContext.sat()) - val failureOccurs = tla.or(nextState.arena.findCellsByType(FailPredT()).map(_.toNameEx): _*) - solverContext.assertGroundExpr(tla.not(failureOccurs)) + case NameEx(_) => assert(solverContext.sat()) solverContext.push() solverContext.assertGroundExpr(nextState.ex) assert(solverContext.sat()) solverContext.pop() solverContext.push() - solverContext.assertGroundExpr(tla.not(nextState.ex)) + solverContext.assertGroundExpr(not(nextState.ex ? "b").typed(types, "b")) assert(!solverContext.sat()) solverContext.pop() @@ -152,28 +163,34 @@ class TestSymbStateRewriterFun extends RewriterBase with TestingPredefs { } } - test("""SE-FUN-APP[1-3]: f[4] ~~> $C$k""") { - def mkSet(elems: TlaEx*) = OperEx(TlaSetOper.enumSet, elems: _*) - - val set = mkSet(ValEx(TlaInt(1)), ValEx(TlaInt(2)), ValEx(TlaInt(3)), ValEx(TlaInt(4))) - val mapping = OperEx(TlaArithOper.mult, NameEx("x"), ValEx(TlaInt(3))) - val fun = OperEx(TlaFunOper.funDef, mapping, NameEx("x"), set) - val app = OperEx(TlaFunOper.app, fun, ValEx(TlaInt(4))) + test("""f[4]""") { + val set = enumSet(1.to(4).map(int): _*) + .typed(types, "I") + val mapping = mult(name("x"), int(3)) + .typed(types, "i") + val fun = funDef(mapping, name("x"), set) + .typed(types, "i_to_i") + val app = appFun(fun, int(4)) + .typed(types, "i") val state = new SymbState(app, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) nextState.ex match { - case membershipEx @ NameEx(name) => + case NameEx(name) => assert(solverContext.sat()) val cell = nextState.arena.findCellByName(name) cell.cellType match { case IntT() => - solverContext.assertGroundExpr(OperEx(TlaOper.eq, cell.toNameEx, ValEx(TlaInt(12)))) + val eq1 = eql(cell.toNameEx ? "", int(12)) + .typed(types, "b") + solverContext.assertGroundExpr(eq1) rewriter.push() assert(solverContext.sat()) rewriter.pop() - solverContext.assertGroundExpr(OperEx(TlaOper.ne, cell.toNameEx, ValEx(TlaInt(12)))) + val eq2 = neql(cell.toNameEx ? "i", int(12)) + .typed(types, "b") + solverContext.assertGroundExpr(eq2) assert(!solverContext.sat()) case _ => @@ -185,34 +202,24 @@ class TestSymbStateRewriterFun extends RewriterBase with TestingPredefs { } } - test("""SE-FUN-APP[1-3]: [x \in {1, 2} |-> x][4] ~~> failure!""") { - def mkSet(elems: TlaEx*) = OperEx(TlaSetOper.enumSet, elems: _*) - - val set = mkSet(ValEx(TlaInt(1)), ValEx(TlaInt(2))) - val mapping = NameEx("x") - val fun = OperEx(TlaFunOper.funDef, mapping, NameEx("x"), set) - val app = OperEx(TlaFunOper.app, fun, ValEx(TlaInt(4))) + test("""[x \in {1, 2} |-> x][4] ~~> failure!""") { + val set = enumSet(int(1), int(2)) + .typed(types, "I") + val fun = funDef(name("x") ? "i", name("x") ? "i", set) + val app = appFun(fun ? "i_to_i", int(4)) + .typed(types, "i") val state = new SymbState(app, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) nextState.ex match { - case membershipEx @ NameEx(name) => + case NameEx(_) => // In the previous version, we were using failure predicates to detect failures. // However, they were an unnecessary burden and produced tonnes of constraints. // In the new version, we just return some value, // which is similar to Leslie's interpretation. // The most important thing is that the SMT context is still satisfiable. assert(solverContext.sat()) - /* - // the code with failure predicates - rewriter.push() - val failureOccurs = tla.or(nextState.arena.findCellsByType(FailPredT()).map(_.toNameEx): _*) - solverContext.assertGroundExpr(failureOccurs) - assert(solverContext.sat()) - solverContext.assertGroundExpr(tla.not(failureOccurs)) - assert(!solverContext.sat()) - */ case _ => fail("Unexpected rewriting result") @@ -222,341 +229,109 @@ class TestSymbStateRewriterFun extends RewriterBase with TestingPredefs { // Raft is directly using f @@ e :> r to construct a function g such as: // DOMAIN g = {e} \cup DOMAIN f and g[e] = r and g[a] = f[a] for a \in DOMAIN f // It is trivial to implement this extension with our encoding - test("""SE-FUN-AT-AT: [x \in {1, 2} |-> x] @@ 3 :> 4""") { - def mkSet(elems: TlaEx*) = OperEx(TlaSetOper.enumSet, elems: _*) - - val set = tla.enumSet(tla.int(1), tla.int(2)) - val mapping = NameEx("x") - val fun = tla.funDef(mapping, tla.name("x"), set) - val extFun = OperEx(TlcOper.atat, fun, OperEx(TlcOper.colonGreater, tla.int(3), tla.int(4))) + test("""[x \in {1, 2} |-> x] @@ 3 :> 4""") { + val set = enumSet(int(1), int(2)) + val fun = funDef(name("x") ? "i", name("x") ? "i", set ? "I") + val extFun = atat(fun ? "i_to_i", colonGreater(int(3), int(4)) ? "i_to_i") + .typed(types, "i_to_i") - val state = new SymbState(extFun, arena, Binding()) - val newFun = state.ex val rewriter = create() - val extState = rewriter.rewriteUntilDone(state) + val extState = rewriter.rewriteUntilDone(new SymbState(extFun, arena, Binding())) assert(solverContext.sat()) - assertTlaExAndRestore(rewriter, extState.setRex(tla.eql(tla.int(4), tla.appFun(newFun, tla.int(3))))) + val eq1 = eql(int(4), appFun(extFun, int(3)) ? "i") + .typed(types, "b") + assertTlaExAndRestore(rewriter, extState.setRex(eq1)) } - test("""SE-FUN-APP[1-3]: [x \in {3} |-> {1, x}][3] ~~> $C$k""") { - def mkSet(elems: TlaEx*) = OperEx(TlaSetOper.enumSet, elems: _*) - - val set = tla.enumSet(tla.int(3)) - val mapping = tla.enumSet(tla.int(1), tla.name("x")) - val fun = tla.funDef(mapping, tla.name("x"), set) - val app = OperEx(TlaFunOper.app, fun, tla.int(3)) + test("""[x \in {3} |-> {1, x}][3]""") { + val set = enumSet(int(3)) + val mapping = enumSet(int(1), name("x") ? "i") + val fun = funDef(mapping ? "I", name("x") ? "i", set ? "I") + val app = appFun(fun ? "i_to_I", int(3)) + .typed(types, "I") + val appEq = eql(app, enumSet(int(1), int(3)) ? "I") + .typed(types, "b") val state = new SymbState(app, arena, Binding()) val rewriter = create() - val nextState = rewriter.rewriteUntilDone(state) - nextState.ex match { - case membershipEx @ NameEx(name) => - assert(solverContext.sat()) - val cell = nextState.arena.findCellByName(name) - cell.cellType match { - case FinSetT(IntT()) => - () // type OK, check equality below - - case _ => - fail("Unexpected type") - } - - case _ => - fail("Unexpected rewriting result") - } - rewriter.push() - val appEq = tla.eql(nextState.ex, tla.enumSet(tla.int(1), tla.int(3))) - val eqState = nextState.setRex(appEq) - create().rewriteUntilDone(eqState).ex match { - case eqEx @ NameEx(name) => - solverContext.assertGroundExpr(eqEx) - assert(solverContext.sat()) - - case _ => - fail("Unexpected rewriting result") - } - rewriter.pop() - val appNeq = tla.not(tla.eql(nextState.ex, tla.enumSet(tla.int(1), tla.int(3)))) - val neqState = nextState.setRex(appNeq) - rewriter.rewriteUntilDone(neqState).ex match { - case eqEx @ NameEx(name) => - solverContext.assertGroundExpr(eqEx) - assert(!solverContext.sat()) - - case _ => - fail("Unexpected rewriting result") - } + assertTlaExAndRestore(rewriter, state.setRex(appEq)) } - test("""SE-FUN-APP[1-3]: [x \in {} |-> x][3]""") { - // regression: function application with an empty domain should not crash - val set = tla.withType(tla.enumSet(), AnnotationParser.toTla(FinSetT(IntT()))) - val fun = tla.funDef(tla.name("x"), tla.name("x"), set) - val app = OperEx(TlaFunOper.app, fun, tla.int(3)) + test("""[x \in {} |-> x][3]""") { + // regression: function application with an empty domain should not crash. + // The result of this function is undefined in TLA+. + val fun = funDef(name("x") ? "i", name("x") ? "i", enumSet() ? "I") + val app = appFun(fun ? "i_to_i", int(3)) + .typed(types, "i") val state = new SymbState(app, arena, Binding()) val rewriter = create() - val nextState = rewriter.rewriteUntilDone(state) + val _ = rewriter.rewriteUntilDone(state) assert(solverContext.sat()) } - test("""SE-FUN-EQ4: [y \in BOOLEAN |-> ~y] = [x \in BOOLEAN |-> ~x]""") { - val fun1 = tla.funDef(tla.not(tla.name("y")), tla.name("y"), ValEx(TlaBoolSet)) - val fun2 = tla.funDef(tla.not(tla.name("x")), tla.name("x"), ValEx(TlaBoolSet)) - val state = new SymbState(tla.eql(fun1, fun2), arena, Binding()) - val rewriter = create() - val nextState = rewriter.rewriteUntilDone(state) - nextState.ex match { - case membershipEx @ NameEx(name) => - rewriter.push() - solverContext.assertGroundExpr(membershipEx) - assert(solverContext.sat()) - rewriter.pop() - solverContext.assertGroundExpr(tla.not(membershipEx)) - val failureOccurs = tla.or(nextState.arena.findCellsByType(FailPredT()).map(_.toNameEx): _*) - solverContext.assertGroundExpr(tla.not(failureOccurs)) - assertUnsatOrExplain(rewriter, nextState) + test("""[y \in BOOLEAN |-> ~y] = [x \in BOOLEAN |-> ~x]""") { + val fun1 = funDef(not(name("y") ? "b") ? "b", name("y") ? "b", booleanSet() ? "B") + .typed(types, "b_to_b") + val fun2 = funDef(not(name("x") ? "b") ? "b", name("x") ? "b", booleanSet() ? "B") + .typed(types, "b_to_b") + val eq1 = eql(fun1, fun2) + .typed(types, "b") - case _ => - fail("Unexpected rewriting result") - } - } - - test("""SE-FUN-NE: ~([y \in BOOLEAN |-> ~y] = [x \in BOOLEAN |-> ~x])""") { - val fun1 = tla.funDef(tla.not(tla.name("y")), tla.name("y"), ValEx(TlaBoolSet)) - val fun2 = tla.funDef(tla.not(tla.name("x")), tla.name("x"), ValEx(TlaBoolSet)) - val state = new SymbState(tla.not(tla.eql(fun1, fun2)), arena, Binding()) + val state = new SymbState(eq1, arena, Binding()) val rewriter = create() - val nextState = rewriter.rewriteUntilDone(state) - nextState.ex match { - case membershipEx @ NameEx(name) => - val failureOccurs = tla.or(nextState.arena.findCellsByType(FailPredT()).map(_.toNameEx): _*) - solverContext.assertGroundExpr(tla.not(failureOccurs)) - rewriter.push() - solverContext.assertGroundExpr(membershipEx) - assert(!solverContext.sat()) - rewriter.pop() - solverContext.assertGroundExpr(tla.not(membershipEx)) - assert(solverContext.sat()) - - case _ => - fail("Unexpected rewriting result") - } - } - - test("""SE-FUN-NE: ~([y \in BOOLEAN |-> ~y] = [x \in BOOLEAN |-> x])""") { - val fun1 = tla.funDef(tla.not(tla.name("y")), tla.name("y"), ValEx(TlaBoolSet)) - val fun2 = tla.funDef(tla.name("x"), tla.name("x"), ValEx(TlaBoolSet)) - val state = new SymbState(tla.not(tla.eql(fun1, fun2)), arena, Binding()) - val rewriter = create() - val nextState = rewriter.rewriteUntilDone(state) - nextState.ex match { - case membershipEx @ NameEx(name) => - val failureOccurs = tla.or(nextState.arena.findCellsByType(FailPredT()).map(_.toNameEx): _*) - solverContext.assertGroundExpr(tla.not(failureOccurs)) - rewriter.push() - solverContext.assertGroundExpr(membershipEx) - val isSat = solverContext.sat() - assert(isSat) - rewriter.pop() - solverContext.assertGroundExpr(tla.not(membershipEx)) - val isUnsat = !solverContext.sat() - assert(isUnsat) - - case _ => - fail("Unexpected rewriting result") - } + assertTlaExAndRestore(rewriter, state) } // a function returning a function - test("""SE-FUN-APP[1-3]: [x \in {3} |-> [y \in BOOLEAN |-> ~y]][3] ~~> $C$k""") { - val set = tla.enumSet(tla.int(3)) - val boolNegFun = tla.funDef(tla.not(tla.name("y")), tla.name("y"), ValEx(TlaBoolSet)) + test("""[x \in {3} |-> [y \in BOOLEAN |-> ~y]][3]""") { + val boolNegFun = funDef(not(name("y") ? "b") ? "b", name("y") ? "b", booleanSet() ? "B") + .typed(types, "b_to_b") - val fun = tla.funDef(boolNegFun, tla.name("x"), set) - val app = OperEx(TlaFunOper.app, fun, tla.int(3)) + val fun = funDef(boolNegFun, name("x") ? "i", enumSet(int(3)) ? "I") + .typed(types, "i_to_b_to_b") + val app = appFun(fun, int(3)) + .typed(types, "b_to_b") - val state = new SymbState(app, arena, Binding()) - val rewriter = create() - val nextState = rewriter.rewriteUntilDone(state) - nextState.ex match { - case membershipEx @ NameEx(name) => - assert(solverContext.sat()) - val cell = nextState.arena.findCellByName(name) - cell.cellType match { - case FunT(FinSetT(BoolT()), BoolT()) => - () // type OK, check equality below - - case _ => - fail("Unexpected type: " + cell.cellType) - } + val appEq = eql(app, boolNegFun) + .typed(BoolT1()) - case _ => - fail("Unexpected rewriting result") - } - rewriter.push() - val appEq = tla.eql(nextState.ex, boolNegFun) - val eqState = rewriter.rewriteUntilDone(nextState.setRex(appEq)) - eqState.ex match { - case eqEx @ NameEx(name) => - solverContext.assertGroundExpr(eqEx) - val isSat = solverContext.sat() - assert(isSat) - - case _ => - fail("Unexpected rewriting result") - } - rewriter.pop() - rewriter.push() - val appNeq = tla.not(tla.eql(nextState.ex, boolNegFun)) - val neqState = rewriter.rewriteUntilDone(nextState.setRex(appNeq)) - neqState.ex match { - case neqEx @ NameEx(name) => - solverContext.assertGroundExpr(neqEx) - val failureOccurs = tla.or(neqState.arena.findCellsByType(FailPredT()).map(_.toNameEx): _*) - solverContext.assertGroundExpr(tla.not(failureOccurs)) - assertUnsatOrExplain(rewriter, neqState) - - case _ => - fail("Unexpected rewriting result") - } - rewriter.pop() + val state = new SymbState(appEq, arena, Binding()) + val rewriter = create() + assertTlaExAndRestore(rewriter, state) } - test("""SE-FUN-APP[1-4]: [x \in {1, 2} |-> IF x = 1 THEN 11 ELSE 2 * x][1] ~~> $C$fun""") { - val set = tla.enumSet(tla.int(1), tla.int(2)) - val pred = tla.eql(tla.name("x"), tla.int(1)) - val ite = tla.ite(pred, tla.int(11), tla.mult(tla.int(2), tla.name("x"))) - val iteFun = tla.funDef(ite, tla.name("x"), set) - val iteFunElem = tla.appFun(iteFun, tla.int(1)) - val iteFunElemNe11 = tla.not(tla.eql(iteFunElem, tla.int(11))) + test("""[x \in {1, 2} |-> IF x = 1 THEN 11 ELSE 2 * x][1]""") { + val set = enumSet(int(1), int(2)) + val pred = eql(name("x") ? "i", int(1)) + val ifThenElse = ite(pred ? "b", int(11), mult(int(2), name("x") ? "i") ? "i") + val iteFun = funDef(ifThenElse ? "i", name("x") ? "i", set ? "I") + val iteFunElem = appFun(iteFun ? "i_to_i", int(1)) + val iteFunElemNe11 = eql(iteFunElem ? "i", int(11)) + .typed(types, "b") val state = new SymbState(iteFunElemNe11, arena, Binding()) val rewriter = create() - val nextState = rewriter.rewriteUntilDone(state) - nextState.ex match { - case resFunEx @ NameEx(name) => - solverContext.assertGroundExpr(resFunEx) - val failureOccurs = tla.or(nextState.arena.findCellsByType(FailPredT()).map(_.toNameEx): _*) - solverContext.assertGroundExpr(tla.not(failureOccurs)) - val isSat = solverContext.sat() - assert(!isSat) - - case _ => - fail("Unexpected rewriting result") - } + assertTlaExAndRestore(rewriter, state) } - test("""SE-FUN-UPD[1-4]: [[x \in {1, 2} |-> 2 * x] EXCEPT ![1] = 11] ~~> $C$fun""") { - val set = tla.enumSet(tla.int(1), tla.int(2)) - val mapExpr = tla.mult(tla.int(2), tla.name("x")) - val fun = tla.funDef(mapExpr, tla.name("x"), set) - - val except = tla.except(fun, tla.tuple(tla.int(1)), tla.int(11)) - val state = new SymbState(except, arena, Binding()) - val rewriter = create() - val nextState = rewriter.rewriteUntilDone(state) - nextState.ex match { - case resFunEx @ NameEx(name) => - // check the function domain and co-domain - val resFun = nextState.asCell - // no domain anymore - // val dom = nextState.arena.getDom(resFun) - // assert(nextState.arena.getHas(dom).size == 2) - val cdm = nextState.arena.getCdm(resFun) - val cdmSize = nextState.arena.getHas(cdm).size - assert(cdmSize == 2 || cdmSize == 3) // the co-domain can be overapproximated - - case _ => - fail("Unexpected rewriting result") - } - - val exceptFun = nextState.asCell + test("""[[x \in {1, 2} |-> 2 * x] EXCEPT ![1] = 11]""") { + val set = enumSet(int(1), int(2)) + val mapExpr = mult(int(2), name("x") ? "i") + val fun = funDef(mapExpr ? "i", name("x") ? "i", set ? "I") + .typed(types, "i_to_i") - val resFun1Ne11 = tla.not(tla.eql(tla.appFun(nextState.ex, tla.int(1)), tla.int(11))) - val cmpState = rewriter.rewriteUntilDone(nextState.setRex(resFun1Ne11)) + val newFun = except(fun, tuple(int(1)) ? "(i)", int(11)) + .typed(types, "i_to_i") - // compare - rewriter.push() - - // make sure that not equals gives us sat - cmpState.ex match { - case neqEx @ NameEx(name) => - solverContext.assertGroundExpr(neqEx) - /* - // not using failure predicates anymore - val failureOccurs = tla.or(cmpState.arena.findCellsByType(FailPredT()).map(_.toNameEx): _*) - solverContext.assertGroundExpr(tla.not(failureOccurs)) - */ - assertUnsatOrExplain(rewriter, cmpState) - - case _ => - fail("Unexpected rewriting result") - } - } - - // In general, the index is a tuple; tla-import gives us a singleton tuple. - test("""SE-FUN-UPD[1-4]: [[x \in {1, 2} |-> 2 * x] EXCEPT ![(1)] = 11] ~~> $C$fun""") { - val set = tla.enumSet(tla.int(1), tla.int(2)) - val mapExpr = tla.mult(tla.int(2), tla.name("x")) - val fun = tla.funDef(mapExpr, tla.name("x"), set) - - val except = tla.except(fun, tla.tuple(tla.int(1)), tla.int(11)) - val state = new SymbState(except, arena, Binding()) + val state = new SymbState(newFun, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) - nextState.ex match { - case resFunEx @ NameEx(name) => - // check the function domain and co-domain - val resFun = nextState.arena.findCellByName(name) - // no domain anymore - // val dom = nextState.arena.getDom(resFun) - // assert(nextState.arena.getHas(dom).size == 2) - val cdm = nextState.arena.getCdm(resFun) - val cdmSize = nextState.arena.getHas(cdm).size - assert(cdmSize == 2 || cdmSize == 3) // the co-domain can be overapproximated - - case _ => - fail("Unexpected rewriting result") - } - - val exceptFun = nextState.arena.findCellByNameEx(nextState.ex) - - val resFun1Ne11 = tla.not(tla.eql(tla.appFun(nextState.ex, tla.int(1)), tla.int(11))) - val cmpState = rewriter.rewriteUntilDone(nextState.setRex(resFun1Ne11)) - - // compare - rewriter.push() - - // make sure that not equals gives us sat - cmpState.ex match { - case neqEx @ NameEx(name) => - solverContext.assertGroundExpr(neqEx) - val failureOccurs = tla.or(cmpState.arena.findCellsByType(FailPredT()).map(_.toNameEx): _*) - solverContext.assertGroundExpr(tla.not(failureOccurs)) - assertUnsatOrExplain(rewriter, cmpState) - - case _ => - fail("Unexpected rewriting result") - } - } - test("""SE-FUN-UPD[1-4] and singleton tuple: [[x \in {1, 2} |-> 2 * x] EXCEPT ![(1)] = 11] ~~> $C$fun""") { - // singleton tuples in EXCEPT are erased and converted into the tuple element - val set = tla.enumSet(tla.int(1), tla.int(2)) - val mapExpr = tla.mult(tla.int(2), tla.name("x")) - val fun = tla.funDef(mapExpr, tla.name("x"), set) - - val except = tla.except(fun, tla.tuple(tla.int(1)), tla.int(11)) - val state = new SymbState(except, arena, Binding()) - val rewriter = create() - val nextState = rewriter.rewriteUntilDone(state) nextState.ex match { - case resFunEx @ NameEx(name) => + case NameEx(_) => // check the function domain and co-domain - val resFun = nextState.arena.findCellByName(name) - // no domain anymore - // val dom = nextState.arena.getDom(resFun) - // assert(nextState.arena.getHas(dom).size == 2) + val resFun = nextState.asCell + // check the function co-domain val cdm = nextState.arena.getCdm(resFun) val cdmSize = nextState.arena.getHas(cdm).size assert(cdmSize == 2 || cdmSize == 3) // the co-domain can be overapproximated @@ -565,131 +340,51 @@ class TestSymbStateRewriterFun extends RewriterBase with TestingPredefs { fail("Unexpected rewriting result") } - val exceptFun = nextState.arena.findCellByNameEx(nextState.ex) - - val resFun1Ne11 = tla.not(tla.eql(tla.appFun(nextState.ex, tla.int(1)), tla.int(11))) - val cmpState = rewriter.rewriteUntilDone(nextState.setRex(resFun1Ne11)) - - // compare - rewriter.push() - - // make sure that not equals gives us sat - cmpState.ex match { - case neqEx @ NameEx(name) => - solverContext.assertGroundExpr(neqEx) - val failureOccurs = tla.or(cmpState.arena.findCellsByType(FailPredT()).map(_.toNameEx): _*) - solverContext.assertGroundExpr(tla.not(failureOccurs)) - assertUnsatOrExplain(rewriter, cmpState) - - case _ => - fail("Unexpected rewriting result") - } + val resFun1eq11 = eql(appFun(newFun, int(1)) ? "i", int(11)) + .typed(types, "b") + assertTlaExAndRestore(rewriter, nextState.setRex(resFun1eq11)) } - test("""SE-FUN-UPD[1-4], singleton tuple, and const: [[x \in {"a", "b"} |-> 3] EXCEPT ![("a")] = 11] ~~> $C$fun""") { - // singleton tuples in EXCEPT are erased and converted into the tuple element - val set = tla.enumSet(tla.str("a"), tla.str("b")) - val mapExpr = tla.int(3) - val fun = tla.funDef(mapExpr, tla.name("x"), set) - - val except = tla.except(fun, tla.tuple(tla.str("a")), tla.int(11)) - val state = new SymbState(except, arena, Binding()) + test("""[[x \in {"a", "b"} |-> 3] EXCEPT !["a"] = 11]""") { + val set = enumSet(str("a"), str("b")) + val mapExpr = int(3) + val fun = funDef(mapExpr ? "i", name("x") ? "s", set ? "S") + .typed(types, "s_to_i") + val newFun = except(fun, tuple(str("a")) ? "(s)", int(11)) + .typed(types, "s_to_i") + val resFun1eq11 = eql(appFun(newFun, str("a")) ? "i", int(11)) + .typed(types, "b") + + val state = new SymbState(newFun, arena, Binding()) val rewriter = create() - val nextState = rewriter.rewriteUntilDone(state) - nextState.ex match { - case resFunEx @ NameEx(name) => - // check the function domain and co-domain - val resFun = nextState.arena.findCellByName(name) - // no domain anymore - // val dom = nextState.arena.getDom(resFun) - // assert(nextState.arena.getHas(dom).size == 2) - val cdm = nextState.arena.getCdm(resFun) - val cdmSize = nextState.arena.getHas(cdm).size - assert(cdmSize == 2 || cdmSize == 3) // the co-domain can be overapproximated - - case _ => - fail("Unexpected rewriting result") - } - - val exceptFun = nextState.arena.findCellByNameEx(nextState.ex) - - val resFun1Ne11 = tla.not(tla.eql(tla.appFun(nextState.ex, tla.str("a")), tla.int(11))) - val cmpState = rewriter.rewriteUntilDone(nextState.setRex(resFun1Ne11)) - - // compare - rewriter.push() - - // make sure that not equals gives us sat - cmpState.ex match { - case neqEx @ NameEx(name) => - solverContext.assertGroundExpr(neqEx) - val failureOccurs = tla.or(cmpState.arena.findCellsByType(FailPredT()).map(_.toNameEx): _*) - solverContext.assertGroundExpr(tla.not(failureOccurs)) - assertUnsatOrExplain(rewriter, cmpState) - - case _ => - fail("Unexpected rewriting result") - } + assertTlaExAndRestore(rewriter, state.setRex(resFun1eq11)) } test("""fun in a set: \E x \in {[y \in BOOLEAN |-> ~y]}: x[FALSE]""") { // this test was failing in the buggy implementation with PICK .. FROM and FUN-MERGE - val fun1 = tla.funDef(tla.not(tla.name("y")), tla.name("y"), ValEx(TlaBoolSet)) - val exists = - OperEx(BmcOper.skolem, tla.exists(tla.name("x"), tla.enumSet(fun1), tla.appFun(NameEx("x"), tla.bool(false)))) + val fun1 = funDef(not(name("y") ? "b") ? "b", name("y") ? "b", booleanSet() ? "B") + .typed(types, "b_to_b") + val existsForm = + apalacheSkolem(exists(name("x") ? "b_to_b", enumSet(fun1) ? "b_TO_b", + appFun(name("x") ? "b_to_b", bool(false)) ? "b") ? "b") + .typed(types, "b") - // here, we have to overred FreeExistentialsStore, and thus cannot use SymbStateRewriterAuto - val typeFinder = new TrivialTypeFinder() - val rewriter = new SymbStateRewriterImpl(solverContext, typeFinder) - typeFinder.inferAndSave(exists) + val rewriter = new SymbStateRewriterImpl(solverContext) - val state = new SymbState(exists, arena, Binding()) - val nextState = rewriter.rewriteUntilDone(state) - val failureOccurs = tla.or(nextState.arena.findCellsByType(FailPredT()).map(_.toNameEx): _*) - solverContext.assertGroundExpr(tla.not(failureOccurs)) - assertTlaExAndRestore(rewriter, nextState) - // check failure predicates - solverContext.assertGroundExpr(nextState.ex) - val failure = tla.or(nextState.arena.findCellsByType(FailPredT()).map(_.toNameEx): _*) - solverContext.assertGroundExpr(failure) - assert(!solverContext.sat()) + val state = new SymbState(existsForm, arena, Binding()) + assertTlaExAndRestore(rewriter, state) } - test("""SE-FUN-DOMAIN: DOMAIN [x \in {1,2,3} |-> x / 2: ]""") { - val set = tla.enumSet(tla.int(1), tla.int(2), tla.int(3)) - val mapping = OperEx(TlaArithOper.div, NameEx("x"), tla.int(2)) - val fun = tla.funDef(mapping, tla.name("x"), set) - val dom = tla.dom(fun) - val eq = tla.eql(dom, set) + test("""DOMAIN [x \in {1,2,3} |-> x / 2: ]""") { + val set = enumSet(int(1), int(2), int(3)) + val mapping = div(name("x"), int(2)) + val fun = funDef(mapping ? "i", name("x") ? "i", set ? "I") + val domain = dom(fun ? "i_to_i") + val eq = eql(domain ? "I", set ? "I") + .typed(types, "b") val rewriter = create() val state = new SymbState(eq, arena, Binding()) assertTlaExAndRestore(rewriter, state) } - - // TrivialTypeFinder does not support let-in and operator declarations - ignore("""SE-SET-APP[1-2]: LET X = {1, 2} \cap {2} IN [y \in X |-> TRUE][2] ~~> $B$k""") { - // regression - val fun = tla.funDef(tla.bool(true), tla.name("y"), tla.name("Oper:X")) - val app = tla.appFun(fun, tla.int(2)) - val ex = tla.letIn(app, - tla.declOp("X", tla.cap(tla.enumSet(tla.int(1), tla.int(2)), tla.enumSet(tla.int(2)))).untypedOperDecl()) - - val state = new SymbState(ex, arena, Binding()) - val rewriter = create() - val nextState = rewriter.rewriteUntilDone(state) - nextState.ex match { - case membershipEx @ NameEx(name) => - assert(solverContext.sat()) // it should be sat - rewriter.push() - val failPreds = nextState.arena.findCellsByType(FailPredT()) - val failureOccurs = tla.or(failPreds.map(_.toNameEx): _*) - solverContext.assertGroundExpr(tla.not(failureOccurs)) - assert(solverContext.sat()) // no deadlock - - case _ => - fail("Unexpected rewriting result") - } - } - } diff --git a/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestSymbStateRewriterFunSet.scala b/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestSymbStateRewriterFunSet.scala index b25e413bf0..62b077db80 100644 --- a/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestSymbStateRewriterFunSet.scala +++ b/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestSymbStateRewriterFunSet.scala @@ -1,19 +1,48 @@ package at.forsyte.apalache.tla.bmcmt import at.forsyte.apalache.tla.bmcmt.types._ -import at.forsyte.apalache.tla.lir.{NameEx, OperEx, TlaEx} -import at.forsyte.apalache.tla.lir.convenience.tla -import at.forsyte.apalache.tla.lir.oper.BmcOper -import at.forsyte.apalache.tla.lir.UntypedPredefs._ +import at.forsyte.apalache.tla.lir.TypedPredefs._ +import at.forsyte.apalache.tla.lir.convenience.tla._ +import at.forsyte.apalache.tla.lir._ import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner @RunWith(classOf[JUnitRunner]) class TestSymbStateRewriterFunSet extends RewriterBase { - test("""SE-FUNSET1: [{1, 2, 3} -> {FALSE, TRUE}]""") { - val domain = tla.enumSet(tla.int(1), tla.int(2), tla.int(3)) - val codomain = tla.enumSet(tla.bool(false), tla.bool(true)) - val state = new SymbState(tla.funSet(domain, codomain), arena, Binding()) + val types = + Map( + "b" -> BoolT1(), + "B" -> SetT1(BoolT1()), + "BB" -> SetT1(SetT1(BoolT1())), + "i" -> IntT1(), + "I" -> SetT1(IntT1()), + "(i)" -> TupT1(IntT1()), + "i_to_i" -> FunT1(IntT1(), IntT1()), + "i_to_I" -> FunT1(IntT1(), SetT1(IntT1())), + "i_TO_i" -> SetT1(FunT1(IntT1(), IntT1())), + "r" -> RecT1("a" -> IntT1()), + "s" -> StrT1(), + "S" -> SetT1(StrT1()), + "(s)" -> TupT1(StrT1()), + "i_to_s" -> FunT1(StrT1(), IntT1()), + "s_to_i" -> FunT1(IntT1(), StrT1()), + "i_to_r" -> FunT1(IntT1(), RecT1("a" -> IntT1())), + "b_to_b" -> FunT1(BoolT1(), BoolT1()), + "b_TO_b" -> SetT1(FunT1(BoolT1(), BoolT1())), + "i_to_b" -> FunT1(IntT1(), BoolT1()), + "i_to_B" -> FunT1(IntT1(), SetT1(BoolT1())), + "i_TO_B" -> FunT1(IntT1(), SetT1(BoolT1())), + "i_TO_i_to_B" -> SetT1(FunT1(IntT1(), FunT1(IntT1(), SetT1(BoolT1())))), + "i_to_i_to_B" -> FunT1(IntT1(), FunT1(IntT1(), SetT1(BoolT1()))), + "i_to_b_to_b" -> FunT1(IntT1(), FunT1(BoolT1(), BoolT1())) + ) + + test("""[{1, 2, 3} -> {FALSE, TRUE}]""") { + val domain = enumSet(int(1), int(2), int(3)) + val codomain = enumSet(bool(false), bool(true)) + val fs = funSet(domain ? "I", codomain ? "B") + .typed(types, "i_to_b") + val state = new SymbState(fs, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) nextState.ex match { @@ -35,10 +64,13 @@ class TestSymbStateRewriterFunSet extends RewriterBase { } } - test("""SE-FUNSET2: [{1, 2} -> Expand(SUBSET {FALSE, TRUE})]""") { - val domain = tla.enumSet(tla.int(1), tla.int(2)) - val codomain = OperEx(BmcOper.expand, tla.powSet(tla.enumSet(tla.bool(false), tla.bool(true)))) - val state = new SymbState(tla.funSet(domain, codomain), arena, Binding()) + test("""[{1, 2} -> Expand(SUBSET {FALSE, TRUE})]""") { + val domain = enumSet(int(1), int(2)) + val codomain = apalacheExpand(powSet(enumSet(bool(false), bool(true)) ? "B") ? "BB") + val fs = funSet(domain ? "I", codomain ? "BB") + .typed(types, "i_to_B") + + val state = new SymbState(fs, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) nextState.ex match { @@ -58,13 +90,16 @@ class TestSymbStateRewriterFunSet extends RewriterBase { } // the existential over a function set should work without expanding the powerset! - test("""SE-FUNSET2: Skolem(\E f \in [{1, 2} -> SUBSET {FALSE, TRUE}]: g' <- f)""") { - val domain = tla.enumSet(tla.int(1), tla.int(2)) - val codomain = tla.powSet(tla.enumSet(tla.bool(false), tla.bool(true))) - val pred = tla.assignPrime(tla.name("g"), tla.name("f")) - val exists = - tla.exists(tla.name("f"), tla.funSet(domain, codomain), pred) - val skolem = OperEx(BmcOper.skolem, exists) + test("""Skolem(\E f \in [{1, 2} -> SUBSET {FALSE, TRUE}]: g' <- f)""") { + val domain = enumSet(int(1), int(2)) ? "I" + val codomain = powSet(enumSet(bool(false), bool(true)) ? "B") ? "BB" + val pred = assign(prime(name("g") ? "i_to_B") ? "i_to_B", name("f") ? "i_to_B") + .typed(types, "b") + val existsForm = + exists(name("f") ? "i_to_B", funSet(domain, codomain) ? "i_to_B", pred) + .typed(types, "b") + val skolem = apalacheSkolem(existsForm) + .typed(BoolT1()) val state = new SymbState(skolem, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) @@ -75,13 +110,15 @@ class TestSymbStateRewriterFunSet extends RewriterBase { } // the existential over a function set should work without expanding the powerset! - test("""SE-FUNSET2: Skolem(\E f \in [{1, 2} -> SUBSET {FALSE}]: f[1] = {TRUE})""") { - val domain = tla.enumSet(tla.int(1), tla.int(2)) - val codomain = tla.powSet(tla.enumSet(tla.bool(false))) - val pred = tla.eql(tla.appFun(tla.name("f"), tla.int(1)), tla.enumSet(tla.bool(true))) - val exists = - tla.exists(tla.name("f"), tla.funSet(domain, codomain), pred) - val skolem = OperEx(BmcOper.skolem, exists) + test("""Skolem(\E f \in [{1, 2} -> SUBSET {FALSE}]: f[1] = {TRUE})""") { + val domain = enumSet(int(1), int(2)) ? "I" + val codomain = powSet(enumSet(bool(false)) ? "B") ? "BB" + val pred = eql(appFun(name("f") ? "i_to_B", int(1)) ? "B", enumSet(bool(true)) ? "B") + .typed(types, "b") + val existsForm = + exists(name("f") ? "i_to_B", funSet(domain, codomain) ? "i_TO_B", pred) + val skolem = apalacheSkolem(existsForm ? "b") + .typed(types, "b") val state = new SymbState(skolem, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) @@ -90,16 +127,20 @@ class TestSymbStateRewriterFunSet extends RewriterBase { } // An existential over a function set that returns a function set to a powerset. Does it blow up your mind? :-) - test("""SE-FUNSET2: Skolem(\E f \in [{1, 2} -> [{3} -> SUBSET {FALSE, TRUE}]]: g' <- f)""") { - val domain1 = tla.enumSet(tla.int(1), tla.int(2)) - val domain2 = tla.enumSet(tla.int(3)) - val codomain2 = tla.powSet(tla.enumSet(tla.bool(false), tla.bool(true))) - val codomain1 = tla.funSet(domain2, codomain2) - val funset = tla.funSet(domain1, codomain1) - val pred = tla.assignPrime(tla.name("g"), tla.name("f")) - val exists = - tla.exists(tla.name("f"), funset, pred) - val skolem = OperEx(BmcOper.skolem, exists) + test("""Skolem(\E f \in [{1, 2} -> [{3} -> SUBSET {FALSE, TRUE}]]: g' <- f)""") { + val domain1 = enumSet(int(1), int(2)) + val domain2 = enumSet(int(3)) + val codomain2 = powSet(enumSet(bool(false), bool(true)) ? "B") ? "BB" + val codomain1 = funSet(domain2 ? "I", codomain2) ? "i_TO_B" + val funset = funSet(domain1 ? "I", codomain1) + .typed(types, "i_TO_i_to_B") + val pred = assign(prime(name("g") ? "i_to_i_to_B") ? "i_to_i_to_B", name("f") ? "i_to_i_to_B") + .typed(types, "b") + val existsForm = + exists(name("f") ? "i_to_i_to_B", funset ? "i_TO_i_to_B", pred) + .typed(types, "b") + val skolem = apalacheSkolem(existsForm) + .typed(types, "b") val state = new SymbState(skolem, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) @@ -110,80 +151,85 @@ class TestSymbStateRewriterFunSet extends RewriterBase { } // this should be fixed by implementing #91 - test("""SE-FUNSET2: [x \in {1, 2} |-> {x = 1}] \in [{1, 2} -> SUBSET {FALSE, TRUE}]""") { - val domain = tla.enumSet(tla.int(1), tla.int(2)) - val codomain = tla.powSet(tla.enumSet(tla.bool(false), tla.bool(true))) - val funset = tla.funSet(domain, codomain) - val fun = tla.funDef(tla.enumSet(tla.eql(tla.name("x"), tla.int(1))), tla.name("x"), domain) - val state = new SymbState(tla.in(fun, funset), arena, Binding()) - val rewriter = create() - val nextState = rewriter.rewriteUntilDone(state) - nextState.ex match { - case NameEx(name) => - solverContext.push() - solverContext.assertGroundExpr(nextState.ex) - assert(solverContext.sat()) - solverContext.pop() - solverContext.push() - solverContext.assertGroundExpr(tla.not(nextState.ex)) - solverContext.pop() - - case _ => - fail("Unexpected rewriting result") - } + test("""[x \in {1, 2} |-> {x = 1}] \in [{1, 2} -> SUBSET {FALSE, TRUE}]""") { + val domain = enumSet(int(1), int(2)) ? "I" + val codomain = powSet(enumSet(bool(false), bool(true)) ? "B") ? "BB" + val funset = funSet(domain, codomain) ? "i_to_B" + val fun = funDef(enumSet(eql(name("x") ? "i", int(1)) ? "b") ? "B", name("x") ? "i", domain) + .typed(types, "i_to_B") + val funInFunSet = in(fun, funset) + .typed(types, "b") + val state = new SymbState(funInFunSet, arena, Binding()) + assertTlaExAndRestore(create(), state) } - test("""SE-FUNSET2: [x \in {1, 2} |-> 3] \in [{1, 2} -> {3, 4}]""") { - val domain = tla.enumSet(tla.int(1), tla.int(2)) - val codomain = tla.enumSet(tla.int(3), tla.int(4)) - val funset = tla.funSet(domain, codomain) - val fun = tla.funDef(tla.int(3), tla.name("x"), domain) - val state = new SymbState(tla.in(fun, funset), arena, Binding()) - val rewriter = create() - val nextState = rewriter.rewriteUntilDone(state) - assert(solverContext.sat()) - assertTlaExAndRestore(rewriter, nextState) + test("""[x \in {1, 2} |-> 3] \in [{1, 2} -> {3, 4}]""") { + val domain = enumSet(int(1), int(2)) ? "I" + val codomain = enumSet(int(3), int(4)) ? "I" + val funset = funSet(domain, codomain) ? "i_TO_i" + val fun = funDef(int(3), name("x") ? "i", domain) + .typed(types, "i_to_i") + val funInFunSet = in(fun, funset) + .typed(types, "b") + val state = new SymbState(funInFunSet, arena, Binding()) + assertTlaExAndRestore(create(), state) } // this should be redundant in the presence of #91 - test("""SE-FUNSET2: [x \in {0, 1, 2} \ {0} |-> 3] \in [{1, 2} -> {3, 4}]""") { + test("""[x \in {0, 1, 2} \ {0} |-> 3] \in [{1, 2} -> {3, 4}]""") { // although 0 is in the function domain at the arena level, it does not belong to the set difference def setminus(set: TlaEx, intVal: Int): TlaEx = { - tla.filter(tla.name("t"), set, tla.not(tla.eql(tla.name("t"), tla.int(intVal)))) + filter(name("t") ? "i", set, not(eql(name("t") ? "i", int(intVal)) ? "b") ? "b") + .typed(types, "I") } - val domain1 = setminus(tla.enumSet(0.to(2).map(tla.int): _*), 0) - val domain2 = tla.enumSet(1.to(2).map(tla.int): _*) - val codomain = tla.enumSet(tla.int(3), tla.int(4)) - val funset = tla.funSet(domain2, codomain) - val fun = tla.funDef(tla.int(3), tla.name("x"), domain1) - val state = new SymbState(tla.in(fun, funset), arena, Binding()) - val rewriter = create() - val nextState = rewriter.rewriteUntilDone(state) - assert(solverContext.sat()) - assertTlaExAndRestore(rewriter, nextState) + val domain1 = setminus( + enumSet(0.to(2).map(int): _*) + .typed(types, "I"), 0) + val domain2 = enumSet(1.to(2).map(int): _*) + .typed(types, "I") + val codomain = enumSet(int(3), int(4)) + .typed(types, "I") + val funset = funSet(domain2, codomain) + .typed(types, "i_TO_i") + val fun = funDef(int(3), name("x") ? "i", domain1) + .typed(types, "i_to_i") + val funInFunSet = in(fun, funset) + .typed(types, "b") + + val state = new SymbState(funInFunSet, arena, Binding()) + assertTlaExAndRestore(create(), state) } // this should be fixed by implementing #91 - test("""SE-FUNSET2: [x \in {1, 2} |-> {TRUE}] \in [{1, 2} -> SUBSET {FALSE}]""") { - val domain = tla.enumSet(tla.int(1), tla.int(2)) - val codomain = tla.powSet(tla.enumSet(tla.bool(false))) - val funset = tla.funSet(domain, codomain) - val fun = tla.funDef(tla.enumSet(tla.bool(true)), tla.name("x"), domain) - val state = new SymbState(tla.in(fun, funset), arena, Binding()) + test("""[x \in {1, 2} |-> {TRUE}] \in [{1, 2} -> SUBSET {FALSE}]""") { + val domain = enumSet(int(1), int(2)) ? "I" + val codomain = powSet(enumSet(bool(false)) ? "B") ? "BB" + val funset = funSet(domain, codomain) + .typed(types, "i_TO_B") + val fun = funDef(enumSet(bool(true)) ? "B", name("x") ? "i", domain) + .typed(types, "i_to_B") + val funInFunSet = in(fun, funset) + .typed(types, "b") + + val state = new SymbState(funInFunSet, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) - assert(solverContext.sat()) - assertTlaExAndRestore(rewriter, nextState.setRex(tla.not(nextState.ex))) + solverContext.assertGroundExpr(nextState.ex) + assert(!solverContext.sat()) } // this should be fixed by implementing #91 - test("""SE-FUNSET with a SUBSET: [x \in {1, 2} |-> {TRUE}] \in [{1, 2} -> SUBSET {FALSE, TRUE}]""") { - val domain = tla.enumSet(tla.int(1), tla.int(2)) - val codomain = tla.powSet(tla.enumSet(tla.bool(false), tla.bool(true))) - val funset = tla.funSet(domain, codomain) - val fun = tla.funDef(tla.enumSet(tla.bool(true)), tla.name("x"), domain) - val state = new SymbState(tla.in(fun, funset), arena, Binding()) + test("""[x \in {1, 2} |-> {TRUE}] \in [{1, 2} -> SUBSET {FALSE, TRUE}]""") { + val domain = enumSet(int(1), int(2)) ? "I" + val codomain = powSet(enumSet(bool(false), bool(true)) ? "B") ? "BB" + val funset = funSet(domain, codomain) + .typed(types, "i_TO_B") + val fun = funDef(enumSet(bool(true)) ? "B", name("x") ? "i", domain) + .typed(types, "i_to_B") + val funInFunSet = in(fun, funset) + .typed(types, "b") + val state = new SymbState(funInFunSet, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) assert(solverContext.sat()) @@ -192,9 +238,11 @@ class TestSymbStateRewriterFunSet extends RewriterBase { // bugfix 27/12/2017 test("""SE-FUNSET1: [0..(5 - 1) -> {FALSE, TRUE}]""") { - val domain = tla.dotdot(tla.int(0), tla.minus(tla.int(5), tla.int(1))) - val codomain = tla.enumSet(tla.bool(false), tla.bool(true)) - val state = new SymbState(tla.funSet(domain, codomain), arena, Binding()) + val domain = dotdot(int(0), minus(int(5), int(1)) ? "i") ? "I" + val codomain = enumSet(bool(false), bool(true)) ? "B" + val fs = funSet(domain, codomain) + .typed(types, "i_to_b") + val state = new SymbState(fs, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) nextState.ex match { diff --git a/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestSymbStateRewriterInt.scala b/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestSymbStateRewriterInt.scala index bdcf9dd408..f80e299b3a 100644 --- a/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestSymbStateRewriterInt.scala +++ b/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestSymbStateRewriterInt.scala @@ -1,42 +1,49 @@ package at.forsyte.apalache.tla.bmcmt import at.forsyte.apalache.tla.bmcmt.types.IntT -import at.forsyte.apalache.tla.lir.convenience.tla -import at.forsyte.apalache.tla.lir.oper._ -import at.forsyte.apalache.tla.lir.values.TlaInt -import at.forsyte.apalache.tla.lir.{NameEx, OperEx, TlaEx, ValEx} -import at.forsyte.apalache.tla.lir.UntypedPredefs._ +import at.forsyte.apalache.tla.lir.TypedPredefs._ +import at.forsyte.apalache.tla.lir.convenience.tla._ +import at.forsyte.apalache.tla.lir.{BoolT1, IntT1, NameEx, SetT1} import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner @RunWith(classOf[JUnitRunner]) class TestSymbStateRewriterInt extends RewriterBase { - test("SE-INT-CELL-EQ1: $C$_i: Int = $C$_j: Int ~~> valInt(...) = valInt(...)") { + private val intTypes = Map("i" -> IntT1(), "I" -> SetT1(IntT1()), "b" -> BoolT1()) + + test("$C$_i: Int = $C$_j: Int") { arena = arena.appendCell(IntT()) val leftCell = arena.topCell arena = arena.appendCell(IntT()) val rightCell = arena.topCell - val state = new SymbState(OperEx(TlaOper.eq, leftCell.toNameEx, rightCell.toNameEx), arena, Binding()) + val eq1 = eql(leftCell.toNameEx ? "i", rightCell.toNameEx ? "i") + .typed(intTypes, "b") + val state = new SymbState(eq1, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) nextState.ex match { - case predEx @ NameEx(name) => + case predEx @ NameEx(_) => assert(solverContext.sat()) - solverContext.assertGroundExpr(OperEx(TlaOper.eq, leftCell.toNameEx, ValEx(TlaInt(22)))) + val eq2 = eql(leftCell.toNameEx ? "i", int((22))) + .typed(intTypes, "b") + solverContext.assertGroundExpr(eq2) rewriter.push() - solverContext.assertGroundExpr(OperEx(TlaOper.eq, rightCell.toNameEx, ValEx(TlaInt(22)))) + val eq3 = eql(rightCell.toNameEx ? "i", int(22)) + .typed(intTypes, "b") + solverContext.assertGroundExpr(eq3) rewriter.push() solverContext.assertGroundExpr(predEx) assert(solverContext.sat()) rewriter.pop() rewriter.push() - solverContext.assertGroundExpr(OperEx(TlaBoolOper.not, predEx)) + solverContext.assertGroundExpr(not(predEx ? "b").typed(intTypes, "b")) assert(!solverContext.sat()) rewriter.pop() rewriter.pop() - solverContext.assertGroundExpr(OperEx(TlaOper.eq, rightCell.toNameEx, ValEx(TlaInt(1981)))) + val eq4 = eql(rightCell.toNameEx ? "i", int((1981))).typed(intTypes, "b") + solverContext.assertGroundExpr(eq4) rewriter.push() - solverContext.assertGroundExpr(OperEx(TlaBoolOper.not, predEx)) + solverContext.assertGroundExpr(not(predEx ? "b").typed(intTypes, "b")) assert(solverContext.sat()) rewriter.pop() solverContext.assertGroundExpr(predEx) @@ -47,32 +54,32 @@ class TestSymbStateRewriterInt extends RewriterBase { } } - test("SE-INT-EQ1: $Z$i = $Z$j ~~> $B$k") { + test("$Z$i = $Z$j ~~> $B$k") { arena = arena.appendCell(IntT()) val leftInt = arena.topCell.toNameEx arena = arena.appendCell(IntT()) val rightInt = arena.topCell.toNameEx - val state = new SymbState(OperEx(TlaOper.eq, leftInt, rightInt), arena, Binding()) + val state = new SymbState(eql(leftInt ? "i", rightInt ? "i").typed(intTypes, "b"), arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) nextState.ex match { case predEx @ NameEx(name) => assert(solverContext.sat()) - solverContext.assertGroundExpr(OperEx(TlaOper.eq, leftInt, ValEx(TlaInt(22)))) + solverContext.assertGroundExpr(eql(leftInt ? "i", int(22)).typed(intTypes, "b")) rewriter.push() - solverContext.assertGroundExpr(OperEx(TlaOper.eq, rightInt, ValEx(TlaInt(22)))) + solverContext.assertGroundExpr(eql(rightInt ? "i", int(22)).typed(intTypes, "b")) rewriter.push() solverContext.assertGroundExpr(predEx) assert(solverContext.sat()) rewriter.pop() rewriter.push() - solverContext.assertGroundExpr(OperEx(TlaBoolOper.not, predEx)) + solverContext.assertGroundExpr(not(predEx ? "b").typed(intTypes, "b")) assert(!solverContext.sat()) rewriter.pop() rewriter.pop() - solverContext.assertGroundExpr(OperEx(TlaOper.eq, rightInt, ValEx(TlaInt(1981)))) + solverContext.assertGroundExpr(eql(rightInt ? "i", int((1981))).typed(intTypes, "b")) rewriter.push() - solverContext.assertGroundExpr(OperEx(TlaBoolOper.not, predEx)) + solverContext.assertGroundExpr(not(predEx ? "b").typed(intTypes, "b")) assert(solverContext.sat()) rewriter.pop() solverContext.assertGroundExpr(predEx) @@ -82,28 +89,29 @@ class TestSymbStateRewriterInt extends RewriterBase { fail("Unexpected rewriting result") } } - test("SE-INT-CELL-CMP1: $C$_i: Int < $C$_j: Int ~~> valInt(...) < valInt(...)") { + test("$C$_i: Int < $C$_j: Int ~~> valInt(...) < valInt(...)") { arena = arena.appendCell(IntT()) val leftCell = arena.topCell arena = arena.appendCell(IntT()) val rightCell = arena.topCell - val state = new SymbState(OperEx(TlaArithOper.lt, leftCell.toNameEx, rightCell.toNameEx), arena, Binding()) + val state = + new SymbState(lt(leftCell.toNameEx ? "i", rightCell.toNameEx ? "i").typed(intTypes, "b"), arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) nextState.ex match { case cmpEx @ NameEx(name) => assert(solverContext.sat()) solverContext.assertGroundExpr(cmpEx) - solverContext.assertGroundExpr(OperEx(TlaOper.eq, leftCell.toNameEx, ValEx(TlaInt(4)))) + solverContext.assertGroundExpr(eql(leftCell.toNameEx ? "i", int(4)).typed(intTypes, "b")) rewriter.push() - solverContext.assertGroundExpr(OperEx(TlaOper.eq, rightCell.toNameEx, ValEx(TlaInt(22)))) + solverContext.assertGroundExpr(eql(rightCell.toNameEx ? "i", int(22)).typed(intTypes, "b")) assert(solverContext.sat()) rewriter.pop() rewriter.push() - solverContext.assertGroundExpr(OperEx(TlaOper.eq, rightCell.toNameEx, ValEx(TlaInt(4)))) + solverContext.assertGroundExpr(eql(rightCell.toNameEx ? "i", int(4)).typed(intTypes, "b")) assert(!solverContext.sat()) rewriter.pop() - solverContext.assertGroundExpr(OperEx(TlaOper.eq, rightCell.toNameEx, ValEx(TlaInt(3)))) + solverContext.assertGroundExpr(eql(rightCell.toNameEx ? "i", int(3)).typed(intTypes, "b")) assert(!solverContext.sat()) case _ => @@ -111,28 +119,29 @@ class TestSymbStateRewriterInt extends RewriterBase { } } - test("SE-INT-CELL-CMP1: $C$_i: Int <= $C$_j: Int ~~> valInt(...) <= valInt(...)") { + test("$C$_i: Int <= $C$_j: Int ~~> valInt(...) <= valInt(...)") { arena = arena.appendCell(IntT()) val leftCell = arena.topCell arena = arena.appendCell(IntT()) val rightCell = arena.topCell - val state = new SymbState(OperEx(TlaArithOper.le, leftCell.toNameEx, rightCell.toNameEx), arena, Binding()) + val state = + new SymbState(le(leftCell.toNameEx ? "i", rightCell.toNameEx ? "i").typed(intTypes, "b"), arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) nextState.ex match { case cmpEx @ NameEx(name) => assert(solverContext.sat()) solverContext.assertGroundExpr(cmpEx) - solverContext.assertGroundExpr(OperEx(TlaOper.eq, leftCell.toNameEx, ValEx(TlaInt(4)))) + solverContext.assertGroundExpr(eql(leftCell.toNameEx ? "i", int(4)).typed(intTypes, "b")) rewriter.push() - solverContext.assertGroundExpr(OperEx(TlaOper.eq, rightCell.toNameEx, ValEx(TlaInt(22)))) + solverContext.assertGroundExpr(eql(rightCell.toNameEx ? "i", int(22)).typed(intTypes, "b")) assert(solverContext.sat()) rewriter.pop() rewriter.push() - solverContext.assertGroundExpr(OperEx(TlaOper.eq, rightCell.toNameEx, ValEx(TlaInt(4)))) + solverContext.assertGroundExpr(eql(rightCell.toNameEx ? "i", int(4)).typed(intTypes, "b")) assert(solverContext.sat()) rewriter.pop() - solverContext.assertGroundExpr(OperEx(TlaOper.eq, rightCell.toNameEx, ValEx(TlaInt(3)))) + solverContext.assertGroundExpr(eql(rightCell.toNameEx ? "i", int(3)).typed(intTypes, "b")) assert(!solverContext.sat()) case _ => @@ -140,28 +149,29 @@ class TestSymbStateRewriterInt extends RewriterBase { } } - test("SE-INT-CELL-CMP1: $C$_i: Int > $C$_j: Int ~~> valInt(...) > valInt(...)") { + test("$C$_i: Int > $C$_j: Int ~~> valInt(...) > valInt(...)") { arena = arena.appendCell(IntT()) val leftCell = arena.topCell arena = arena.appendCell(IntT()) val rightCell = arena.topCell - val state = new SymbState(OperEx(TlaArithOper.gt, leftCell.toNameEx, rightCell.toNameEx), arena, Binding()) + val state = + new SymbState(gt(leftCell.toNameEx ? "i", rightCell.toNameEx ? "i").typed(intTypes, "b"), arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) nextState.ex match { case cmpEx @ NameEx(name) => assert(solverContext.sat()) solverContext.assertGroundExpr(cmpEx) - solverContext.assertGroundExpr(OperEx(TlaOper.eq, leftCell.toNameEx, ValEx(TlaInt(4)))) + solverContext.assertGroundExpr(eql(leftCell.toNameEx ? "i", int(4)).typed(intTypes, "b")) rewriter.push() - solverContext.assertGroundExpr(OperEx(TlaOper.eq, rightCell.toNameEx, ValEx(TlaInt(22)))) + solverContext.assertGroundExpr(eql(rightCell.toNameEx ? "i", int(22)).typed(intTypes, "b")) assert(!solverContext.sat()) rewriter.pop() rewriter.push() - solverContext.assertGroundExpr(OperEx(TlaOper.eq, rightCell.toNameEx, ValEx(TlaInt(4)))) + solverContext.assertGroundExpr(eql(rightCell.toNameEx ? "i", int(4)).typed(intTypes, "b")) assert(!solverContext.sat()) rewriter.pop() - solverContext.assertGroundExpr(OperEx(TlaOper.eq, rightCell.toNameEx, ValEx(TlaInt(3)))) + solverContext.assertGroundExpr(eql(rightCell.toNameEx ? "i", int(3)).typed(intTypes, "b")) assert(solverContext.sat()) case _ => @@ -169,10 +179,10 @@ class TestSymbStateRewriterInt extends RewriterBase { } } - test("SE-INT-CMP1 (composite expressions): 1 + 5 > 6 - 3 ~~> $B$_k") { - val left = OperEx(TlaArithOper.plus, ValEx(TlaInt(1)), ValEx(TlaInt(5))) - val right = OperEx(TlaArithOper.minus, ValEx(TlaInt(6)), ValEx(TlaInt(3))) - val state = new SymbState(OperEx(TlaArithOper.gt, left, right), arena, Binding()) + test("(composite expressions): 1 + 5 > 6 - 3 ~~> $B$_k") { + val left = plus(int(1), int(5)).typed(IntT1()) + val right = minus(int(6), int(3)).typed(IntT1()) + val state = new SymbState(gt(left, right).typed(BoolT1()), arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) nextState.ex match { @@ -182,7 +192,7 @@ class TestSymbStateRewriterInt extends RewriterBase { solverContext.assertGroundExpr(cmpEx) assert(solverContext.sat()) rewriter.pop() - solverContext.assertGroundExpr(OperEx(TlaBoolOper.not, cmpEx)) + solverContext.assertGroundExpr(not(cmpEx ? "b").typed(intTypes, "b")) assert(!solverContext.sat()) case _ => @@ -190,28 +200,29 @@ class TestSymbStateRewriterInt extends RewriterBase { } } - test("SE-INT-CELL-CMP1: $C$_i: Int >= $C$_j: Int ~~> valInt(...) >= valInt(...)") { + test("$C$_i: Int >= $C$_j: Int ~~> valInt(...) >= valInt(...)") { arena = arena.appendCell(IntT()) val leftCell = arena.topCell arena = arena.appendCell(IntT()) val rightCell = arena.topCell - val state = new SymbState(OperEx(TlaArithOper.ge, leftCell.toNameEx, rightCell.toNameEx), arena, Binding()) + val state = + new SymbState(ge(leftCell.toNameEx ? "i", rightCell.toNameEx ? "i").typed(intTypes, "b"), arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) nextState.ex match { case cmpEx @ NameEx(name) => assert(solverContext.sat()) solverContext.assertGroundExpr(cmpEx) - solverContext.assertGroundExpr(OperEx(TlaOper.eq, leftCell.toNameEx, ValEx(TlaInt(4)))) + solverContext.assertGroundExpr(eql(leftCell.toNameEx ? "i", int(4)).typed(intTypes, "b")) rewriter.push() - solverContext.assertGroundExpr(OperEx(TlaOper.eq, rightCell.toNameEx, ValEx(TlaInt(22)))) + solverContext.assertGroundExpr(eql(rightCell.toNameEx ? "i", int(22)).typed(intTypes, "b")) assert(!solverContext.sat()) rewriter.pop() rewriter.push() - solverContext.assertGroundExpr(OperEx(TlaOper.eq, rightCell.toNameEx, ValEx(TlaInt(4)))) + solverContext.assertGroundExpr(eql(rightCell.toNameEx ? "i", int(4)).typed(intTypes, "b")) assert(solverContext.sat()) rewriter.pop() - solverContext.assertGroundExpr(OperEx(TlaOper.eq, rightCell.toNameEx, ValEx(TlaInt(3)))) + solverContext.assertGroundExpr(eql(rightCell.toNameEx ? "i", int(3)).typed(intTypes, "b")) assert(solverContext.sat()) case _ => @@ -219,32 +230,33 @@ class TestSymbStateRewriterInt extends RewriterBase { } } - test("SE-INT-CMP1: ~($Z$i = $Z$j) ~~> $B$k") { + test("~($Z$i = $Z$j) ~~> $B$k") { arena = arena.appendCell(IntT()) val leftInt = arena.topCell.toNameEx arena = arena.appendCell(IntT()) val rightInt = arena.topCell.toNameEx - val state = new SymbState(tla.not(tla.eql(leftInt, rightInt)), arena, Binding()) + val state = + new SymbState(not(eql(leftInt ? "i", rightInt ? "i") ? "b").typed(intTypes, "b"), arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) nextState.ex match { case predEx @ NameEx(name) => assert(solverContext.sat()) - solverContext.assertGroundExpr(OperEx(TlaOper.eq, leftInt, ValEx(TlaInt(22)))) + solverContext.assertGroundExpr(eql(leftInt ? "i", int(22)).typed(intTypes, "b")) rewriter.push() - solverContext.assertGroundExpr(OperEx(TlaOper.eq, rightInt, ValEx(TlaInt(22)))) + solverContext.assertGroundExpr(eql(rightInt ? "i", int(22)).typed(intTypes, "b")) rewriter.push() solverContext.assertGroundExpr(predEx) assert(!solverContext.sat()) rewriter.pop() rewriter.push() - solverContext.assertGroundExpr(OperEx(TlaBoolOper.not, predEx)) + solverContext.assertGroundExpr(not((predEx ? "b")).typed(intTypes, "b")) assert(solverContext.sat()) rewriter.pop() rewriter.pop() - solverContext.assertGroundExpr(OperEx(TlaOper.eq, rightInt, ValEx(TlaInt(1981)))) + solverContext.assertGroundExpr(eql(rightInt ? "i", int(1981)).typed(intTypes, "b")) rewriter.push() - solverContext.assertGroundExpr(OperEx(TlaBoolOper.not, predEx)) + solverContext.assertGroundExpr(not(predEx ? "b").typed(intTypes, "b")) assert(!solverContext.sat()) rewriter.pop() solverContext.assertGroundExpr(predEx) @@ -255,27 +267,27 @@ class TestSymbStateRewriterInt extends RewriterBase { } } - test("SE-INT-ARITH1[+]: $Z$i + $Z$j ~~> $Z$k") { + test("$Z$i + $Z$j ~~> $Z$k") { arena = arena.appendCell(IntT()) val leftInt = arena.topCell.toNameEx arena = arena.appendCell(IntT()) val rightInt = arena.topCell.toNameEx - val expr = OperEx(TlaArithOper.plus, leftInt, rightInt) + val expr = plus(leftInt ? "i", rightInt ? "i").typed(intTypes, "i") val state = new SymbState(expr, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) nextState.ex match { case result @ NameEx(name) => assert(solverContext.sat()) - solverContext.assertGroundExpr(OperEx(TlaOper.eq, leftInt, ValEx(TlaInt(1981)))) + solverContext.assertGroundExpr(eql(leftInt ? "i", int(1981)).typed(intTypes, "b")) rewriter.push() - solverContext.assertGroundExpr(OperEx(TlaOper.eq, rightInt, ValEx(TlaInt(36)))) + solverContext.assertGroundExpr(eql(rightInt ? "i", int(36)).typed(intTypes, "b")) rewriter.push() - solverContext.assertGroundExpr(OperEx(TlaOper.eq, result, ValEx(TlaInt(2017)))) + solverContext.assertGroundExpr(eql(result ? "i", int(2017)).typed(intTypes, "b")) assert(solverContext.sat()) rewriter.pop() rewriter.push() - solverContext.assertGroundExpr(OperEx(TlaOper.eq, result, ValEx(TlaInt(2016)))) + solverContext.assertGroundExpr(eql(result ? "i", int(2016)).typed(intTypes, "b")) assert(!solverContext.sat()) case _ => @@ -283,27 +295,27 @@ class TestSymbStateRewriterInt extends RewriterBase { } } - test("SE-INT-ARITH1[-]: $Z$i - $Z$j ~~> $Z$k") { + test("$Z$i - $Z$j ~~> $Z$k") { arena = arena.appendCell(IntT()) val leftInt = arena.topCell.toNameEx arena = arena.appendCell(IntT()) val rightInt = arena.topCell.toNameEx - val expr = OperEx(TlaArithOper.minus, leftInt, rightInt) + val expr = minus(leftInt ? "i", rightInt ? "i").typed(intTypes, "i") val state = new SymbState(expr, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) nextState.ex match { case result @ NameEx(name) => assert(solverContext.sat()) - solverContext.assertGroundExpr(OperEx(TlaOper.eq, leftInt, ValEx(TlaInt(2017)))) + solverContext.assertGroundExpr(eql(leftInt ? "i", int(2017)).typed(intTypes, "b")) rewriter.push() - solverContext.assertGroundExpr(OperEx(TlaOper.eq, rightInt, ValEx(TlaInt(36)))) + solverContext.assertGroundExpr(eql(rightInt ? "i", int(36)).typed(intTypes, "b")) rewriter.push() - solverContext.assertGroundExpr(OperEx(TlaOper.eq, result, ValEx(TlaInt(1981)))) + solverContext.assertGroundExpr(eql(result ? "i", int(1981)).typed(intTypes, "b")) assert(solverContext.sat()) rewriter.pop() rewriter.push() - solverContext.assertGroundExpr(OperEx(TlaOper.eq, result, ValEx(TlaInt(1980)))) + solverContext.assertGroundExpr(eql(result ? "i", int(1980)).typed(intTypes, "b")) assert(!solverContext.sat()) case _ => @@ -311,23 +323,23 @@ class TestSymbStateRewriterInt extends RewriterBase { } } - test("SE-INT-ARITH1[-.]: -$Z$j ~~> $Z$k") { + test("-$Z$j ~~> $Z$k") { arena = arena.appendCell(IntT()) val leftInt = arena.topCell.toNameEx - val expr = OperEx(TlaArithOper.uminus, leftInt) + val expr = uminus(leftInt ? "i").typed(intTypes, "i") val state = new SymbState(expr, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) nextState.ex match { case result @ NameEx(name) => assert(solverContext.sat()) - solverContext.assertGroundExpr(OperEx(TlaOper.eq, leftInt, ValEx(TlaInt(2017)))) + solverContext.assertGroundExpr(eql(leftInt ? "i", int(2017)).typed(intTypes, "b")) rewriter.push() - solverContext.assertGroundExpr(OperEx(TlaOper.eq, result, ValEx(TlaInt(-2017)))) + solverContext.assertGroundExpr(eql(result ? "i", int(-2017)).typed(intTypes, "b")) assert(solverContext.sat()) rewriter.pop() rewriter.push() - solverContext.assertGroundExpr(OperEx(TlaOper.eq, result, ValEx(TlaInt(2017)))) + solverContext.assertGroundExpr(eql(result ? "i", int(2017)).typed(intTypes, "b")) assert(!solverContext.sat()) case _ => @@ -335,27 +347,27 @@ class TestSymbStateRewriterInt extends RewriterBase { } } - test("SE-INT-ARITH1[*]: $Z$i * $Z$j ~~> $Z$k") { + test("$Z$i * $Z$j ~~> $Z$k") { arena = arena.appendCell(IntT()) val leftInt = arena.topCell.toNameEx arena = arena.appendCell(IntT()) val rightInt = arena.topCell.toNameEx - val expr = OperEx(TlaArithOper.mult, leftInt, rightInt) + val expr = mult(leftInt ? "i", rightInt ? "i").typed(intTypes, "i") val state = new SymbState(expr, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) nextState.ex match { case result @ NameEx(name) => assert(solverContext.sat()) - solverContext.assertGroundExpr(OperEx(TlaOper.eq, leftInt, ValEx(TlaInt(7)))) + solverContext.assertGroundExpr(eql(leftInt ? "i", int(7)).typed(intTypes, "b")) rewriter.push() - solverContext.assertGroundExpr(OperEx(TlaOper.eq, rightInt, ValEx(TlaInt(4)))) + solverContext.assertGroundExpr(eql(rightInt ? "i", int(4)).typed(intTypes, "b")) rewriter.push() - solverContext.assertGroundExpr(OperEx(TlaOper.eq, result, ValEx(TlaInt(28)))) + solverContext.assertGroundExpr(eql(result ? "i", int(28)).typed(intTypes, "b")) assert(solverContext.sat()) rewriter.pop() rewriter.push() - solverContext.assertGroundExpr(OperEx(TlaOper.eq, result, ValEx(TlaInt(30)))) + solverContext.assertGroundExpr(eql(result ? "i", int(30)).typed(intTypes, "b")) assert(!solverContext.sat()) case _ => @@ -363,27 +375,27 @@ class TestSymbStateRewriterInt extends RewriterBase { } } - test("SE-INT-ARITH1[/]: $Z$i / $Z$j ~~> $Z$k") { + test("$Z$i / $Z$j ~~> $Z$k") { arena = arena.appendCell(IntT()) val leftInt = arena.topCell.toNameEx arena = arena.appendCell(IntT()) val rightInt = arena.topCell.toNameEx - val expr = OperEx(TlaArithOper.div, leftInt, rightInt) + val expr = div(leftInt ? "i", rightInt ? "i").typed(intTypes, "i") val state = new SymbState(expr, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) nextState.ex match { case result @ NameEx(name) => assert(solverContext.sat()) - solverContext.assertGroundExpr(OperEx(TlaOper.eq, leftInt, ValEx(TlaInt(30)))) + solverContext.assertGroundExpr(eql(leftInt ? "i", int(30)).typed(intTypes, "b")) rewriter.push() - solverContext.assertGroundExpr(OperEx(TlaOper.eq, rightInt, ValEx(TlaInt(4)))) + solverContext.assertGroundExpr(eql(rightInt ? "i", int(4)).typed(intTypes, "b")) rewriter.push() - solverContext.assertGroundExpr(OperEx(TlaOper.eq, result, ValEx(TlaInt(7)))) + solverContext.assertGroundExpr(eql(result ? "i", int(7)).typed(intTypes, "b")) assert(solverContext.sat()) rewriter.pop() rewriter.push() - solverContext.assertGroundExpr(OperEx(TlaOper.eq, result, ValEx(TlaInt(8)))) + solverContext.assertGroundExpr(eql(result ? "i", int(8)).typed(intTypes, "b")) assert(!solverContext.sat()) case _ => @@ -391,27 +403,27 @@ class TestSymbStateRewriterInt extends RewriterBase { } } - test("SE-INT-ARITH1[%]: $Z$i % $Z$j ~~> $Z$k") { + test("$Z$i % $Z$j ~~> $Z$k") { arena = arena.appendCell(IntT()) val leftInt = arena.topCell.toNameEx arena = arena.appendCell(IntT()) val rightInt = arena.topCell.toNameEx - val expr = OperEx(TlaArithOper.mod, leftInt, rightInt) + val expr = mod(leftInt ? "i", rightInt ? "i").typed(intTypes, "i") val state = new SymbState(expr, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) nextState.ex match { case result @ NameEx(name) => assert(solverContext.sat()) - solverContext.assertGroundExpr(OperEx(TlaOper.eq, leftInt, ValEx(TlaInt(30)))) + solverContext.assertGroundExpr(eql(leftInt ? "i", int(30)).typed(intTypes, "b")) rewriter.push() - solverContext.assertGroundExpr(OperEx(TlaOper.eq, rightInt, ValEx(TlaInt(7)))) + solverContext.assertGroundExpr(eql(rightInt ? "i", int(7)).typed(intTypes, "b")) rewriter.push() - solverContext.assertGroundExpr(OperEx(TlaOper.eq, result, ValEx(TlaInt(2)))) + solverContext.assertGroundExpr(eql(result ? "i", int(2)).typed(intTypes, "b")) assert(solverContext.sat()) rewriter.pop() rewriter.push() - solverContext.assertGroundExpr(OperEx(TlaOper.eq, result, ValEx(TlaInt(1)))) + solverContext.assertGroundExpr(eql(result ? "i", int(1)).typed(intTypes, "b")) assert(!solverContext.sat()) case _ => @@ -419,12 +431,10 @@ class TestSymbStateRewriterInt extends RewriterBase { } } - test("""SE-INT-RNG: 2..5 = {2, 3, 4, 5}""") { - def mkSet(elems: TlaEx*) = OperEx(TlaSetOper.enumSet, elems: _*) - - val expected = mkSet(Range(2, 6).map(i => ValEx(TlaInt(i))): _*) - val range = OperEx(TlaArithOper.dotdot, ValEx(TlaInt(2)), ValEx(TlaInt(5))) - val eqExpected = OperEx(TlaOper.eq, range, expected) + test("""2..5 = {2, 3, 4, 5}""") { + val expected = enumSet(2.until(6).map(int): _*).typed(SetT1(IntT1())) + val range = dotdot(int(2), int(5)).typed(SetT1(IntT1())) + val eqExpected = eql(range, expected).typed(BoolT1()) val state = new SymbState(eqExpected, arena, Binding()) val rewriter = create() @@ -437,7 +447,7 @@ class TestSymbStateRewriterInt extends RewriterBase { solverContext.assertGroundExpr(predEx) assert(solverContext.sat()) rewriter.pop() - solverContext.assertGroundExpr(OperEx(TlaBoolOper.not, predEx)) + solverContext.assertGroundExpr(not(predEx ? "b").typed(intTypes, "b")) assert(!solverContext.sat()) case _ => @@ -446,9 +456,10 @@ class TestSymbStateRewriterInt extends RewriterBase { } test("""SE-INT-RNG: 2..(6 - 1) = {2, 3, 4, 5}""") { - val expected = tla.enumSet(2.to(5).map(i => tla.int(i)): _*).untyped() - val range = tla.dotdot(tla.int(2), tla.minus(tla.int(6), tla.int(1))).untyped() - val eqExpected = tla.eql(range, expected).untyped() + val expected = enumSet(2.to(5).map(int): _*).typed(SetT1(IntT1())) + val range = dotdot(int(2), minus(int(6), int(1)) ? "i") + .typed(intTypes, "I") + val eqExpected = eql(range, expected).typed(BoolT1()) val state = new SymbState(eqExpected, arena, Binding()) val rewriter = create() @@ -461,7 +472,7 @@ class TestSymbStateRewriterInt extends RewriterBase { solverContext.assertGroundExpr(predEx) assert(solverContext.sat()) rewriter.pop() - val notPred = tla.not(predEx).untyped() + val notPred = not(predEx ? "b").typed(intTypes, "b") solverContext.assertGroundExpr(notPred) assert(!solverContext.sat()) diff --git a/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestSymbStateRewriterPowerset.scala b/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestSymbStateRewriterPowerset.scala index dd43db6903..531974ffe4 100644 --- a/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestSymbStateRewriterPowerset.scala +++ b/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestSymbStateRewriterPowerset.scala @@ -1,23 +1,30 @@ package at.forsyte.apalache.tla.bmcmt import at.forsyte.apalache.tla.bmcmt.rules.aux.PowSetCtor -import at.forsyte.apalache.tla.bmcmt.types.{AnnotationParser, FinSetT, IntT, PowSetT} -import at.forsyte.apalache.tla.lir.{NameEx, OperEx} -import at.forsyte.apalache.tla.lir.convenience.tla -import at.forsyte.apalache.tla.lir.oper.BmcOper -import at.forsyte.apalache.tla.lir.UntypedPredefs._ +import at.forsyte.apalache.tla.bmcmt.types.{FinSetT, IntT, PowSetT} +import at.forsyte.apalache.tla.lir.TypedPredefs._ +import at.forsyte.apalache.tla.lir.convenience.tla._ +import at.forsyte.apalache.tla.lir.{BoolT1, IntT1, NameEx, SetT1} import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner @RunWith(classOf[JUnitRunner]) class TestSymbStateRewriterPowerset extends RewriterBase { - test("""SE-SUBSET1: SUBSET {1, 2, 3} ~~> c_set""") { - val ex = tla.powSet(tla.enumSet(tla.int(1), tla.int(2), tla.int(3))) + private val types = Map( + "i" -> IntT1(), + "I" -> SetT1(IntT1()), + "II" -> SetT1(SetT1(IntT1())), + "b" -> BoolT1() + ) + + test("""SUBSET {1, 2, 3}""") { + val ex = powSet(enumSet(int(1), int(2), int(3)) ? "I") + .typed(types, "II") val state = new SymbState(ex, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) nextState.ex match { - case NameEx(name) => + case NameEx(_) => val cell = nextState.arena.findCellByNameEx(nextState.ex) assert(cell.cellType == PowSetT(FinSetT(IntT()))) val dom = nextState.arena.getDom(cell) @@ -32,114 +39,71 @@ class TestSymbStateRewriterPowerset extends RewriterBase { } test("""SE-SUBSET1: {1, 2} \in SUBSET {1, 2, 3}""") { - val set12 = tla.enumSet(tla.int(1), tla.int(2)) - val powset = tla.powSet(tla.enumSet(tla.int(1), tla.int(2), tla.int(3))) - val in = tla.in(set12, powset) - val state = new SymbState(in, arena, Binding()) - val rewriter = create() - val nextState = rewriter.rewriteUntilDone(state) - assert(solverContext.sat()) - rewriter.push() - solverContext.assertGroundExpr(nextState.ex) - assert(solverContext.sat()) - rewriter.pop() - rewriter.push() - solverContext.assertGroundExpr(tla.not(nextState.ex)) - assertUnsatOrExplain(rewriter, nextState) + val set12 = enumSet(int(1), int(2)) ? "I" + val powset = powSet(enumSet(int(1), int(2), int(3)) ? "I") ? "II" + val inEx = in(set12, powset) + .typed(types, "b") + val state = new SymbState(inEx, arena, Binding()) + assertTlaExAndRestore(create(), state) } test("""SE-SUBSET1: {} \in SUBSET {1, 2, 3}""") { // an empty set requires a type annotation - val set12 = tla.withType(tla.enumSet(), AnnotationParser.toTla(FinSetT(IntT()))) - val powset = tla.powSet(tla.enumSet(tla.int(1), tla.int(2), tla.int(3))) - val in = tla.in(set12, powset) - val state = new SymbState(in, arena, Binding()) - val rewriter = create() - val nextState = rewriter.rewriteUntilDone(state) - assert(solverContext.sat()) - rewriter.push() - solverContext.assertGroundExpr(nextState.ex) - assert(solverContext.sat()) - rewriter.pop() - rewriter.push() - solverContext.assertGroundExpr(tla.not(nextState.ex)) - assertUnsatOrExplain(rewriter, nextState) + val emptySet = enumSet() + .typed(types, "I") + val powset = powSet(enumSet(int(1), int(2), int(3)) ? "I") ? "II" + val inEx = in(emptySet, powset) + .typed(types, "b") + val state = new SymbState(inEx, arena, Binding()) + assertTlaExAndRestore(create(), state) } test("""SE-SUBSET1: {1, 2, 3} \in SUBSET {1, 2, 3}""") { - val set1to3 = tla.enumSet(tla.int(1), tla.int(2), tla.int(3)) - val powset = tla.powSet(set1to3) - val in = tla.in(set1to3, powset) - val state = new SymbState(in, arena, Binding()) - val rewriter = create() - val nextState = rewriter.rewriteUntilDone(state) - assert(solverContext.sat()) - rewriter.push() - solverContext.assertGroundExpr(nextState.ex) - assert(solverContext.sat()) - rewriter.pop() - rewriter.push() - solverContext.assertGroundExpr(tla.not(nextState.ex)) - assertUnsatOrExplain(rewriter, nextState) + val set1to3 = enumSet(int(1), int(2), int(3)) ? "I" + val powset = powSet(set1to3) ? "II" + val inEx = in(set1to3, powset) + .typed(types, "b") + val state = new SymbState(inEx, arena, Binding()) + assertTlaExAndRestore(create(), state) } test("""SE-SUBSET1: {1, 2, 3, 4} \in SUBSET {1, 2, 3}""") { - def setTo(k: Int) = tla.enumSet(1 to k map tla.int: _*) - - val set1to4 = setTo(4) - val powset = tla.powSet(setTo(3)) - val in = tla.in(set1to4, powset) - val state = new SymbState(in, arena, Binding()) - val rewriter = create() - val nextState = rewriter.rewriteUntilDone(state) - assert(solverContext.sat()) - rewriter.push() - solverContext.assertGroundExpr(nextState.ex) - assertUnsatOrExplain(rewriter, nextState) - rewriter.pop() - rewriter.push() - solverContext.assertGroundExpr(tla.not(nextState.ex)) - assert(solverContext.sat()) + def setTo(k: Int) = enumSet(1 to k map int: _*) + + val set1to4 = setTo(4) ? "I" + val powset = powSet(setTo(3) ? "I") ? "II" + val inEx = not(in(set1to4, powset) ? "b") + .typed(types, "b") + val state = new SymbState(inEx, arena, Binding()) + assertTlaExAndRestore(create(), state) } test("""SE-SUBSET: \E X \in SUBSET {1, 2}: TRUE (sat)""") { // a regression test that failed in the previous versions - val set = tla.enumSet(tla.int(1), tla.int(2)) - val ex = tla.exists(tla.name("X"), tla.powSet(set), tla.bool(true)) + val set = enumSet(int(1), int(2)) ? "I" + val ex = exists(name("X") ? "I", powSet(set) ? "II", bool(true)) + .typed(types, "b") val state = new SymbState(ex, arena, Binding()) val rewriter = create() try { - val nextState = rewriter.rewriteUntilDone(state) + val _ = rewriter.rewriteUntilDone(state) fail("expected an error message about unfolding a powerset") } catch { case _: UnsupportedOperationException => () // OK } - // nextState.ex match { - // case predEx@NameEx(name) => - // assert(BoolTheory().hasConst(name)) - // rewriter.push() - // solverContext.assertGroundExpr(predEx) - // assert(solverContext.sat()) - // rewriter.pop() - // rewriter.push() - // solverContext.assertGroundExpr(tla.not(predEx)) - // assertUnsatOrExplain(rewriter, nextState) - // - // case _ => - // fail("Unexpected rewriting result") - // } } test("""SE-SUBSET: Skolem(\E X \in SUBSET {1, 2}: TRUE) (sat)""") { // a regression test that failed in the previous versions - val set = tla.enumSet(tla.int(1), tla.int(2)) + val set = enumSet(int(1), int(2)) ? "I" val ex = - OperEx(BmcOper.skolem, tla.exists(tla.name("X"), tla.powSet(set), tla.bool(true))) + apalacheSkolem(exists(name("X") ? "I", powSet(set) ? "II", bool(true)) ? "b") + .typed(types, "b") val state = new SymbState(ex, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) nextState.ex match { - case predEx @ NameEx(name) => + case predEx @ NameEx(_) => rewriter.push() solverContext.assertGroundExpr(predEx) assert(solverContext.sat()) @@ -151,15 +115,16 @@ class TestSymbStateRewriterPowerset extends RewriterBase { test("""SE-SUBSET: Skolem(\E X \in SUBSET {1, 2}: FALSE (unsat))""") { // a regression test that failed in the previous versions - val set = tla.enumSet(tla.int(1), tla.int(2)) + val set = enumSet(int(1), int(2)) ? "I" val ex = - OperEx(BmcOper.skolem, tla.exists(tla.name("X"), tla.powSet(set), tla.bool(false))) + apalacheSkolem(exists(name("X") ? "I", powSet(set) ? "II", bool(false)) ? "b") + .typed(types, "b") val state = new SymbState(ex, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) nextState.ex match { - case predEx @ NameEx(name) => + case predEx @ NameEx(_) => rewriter.push() solverContext.assertGroundExpr(predEx) assertUnsatOrExplain(rewriter, nextState) @@ -170,19 +135,18 @@ class TestSymbStateRewriterPowerset extends RewriterBase { } test("""PowSetCtor {1, 2}""") { - val baseset = tla.enumSet(tla.int(1), tla.int(2)) + val baseset = enumSet(int(1), int(2)) + .typed(types, "I") val state = new SymbState(baseset, arena, Binding()) val rewriter = create() var nextState = rewriter.rewriteUntilDone(state) val baseCell = nextState.asCell nextState = new PowSetCtor(rewriter).confringo(nextState, baseCell) val powCell = nextState.asCell - // give the cell type to type finder - rewriter.typeFinder.reset(rewriter.typeFinder.varTypes + (powCell.toString -> powCell.cellType)) // check equality - val eq = tla.eql(nextState.ex, - tla.enumSet(tla.withType(tla.enumSet(), AnnotationParser.toTla(FinSetT(IntT()))), tla.enumSet(tla.int(1)), - tla.enumSet(tla.int(2)), tla.enumSet(tla.int(1), tla.int(2)))) + val eq = eql(nextState.ex, + enumSet(enumSet() ? "I", enumSet(int(1)) ? "I", enumSet(int(2)) ? "I", enumSet(int(1), int(2)) ? "I") ? "II") + .typed(types, "b") assertTlaExAndRestore(rewriter, nextState.setRex(eq)) } diff --git a/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestSymbStateRewriterRecFun.scala b/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestSymbStateRewriterRecFun.scala index be7bc8008f..2ca54fbba4 100644 --- a/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestSymbStateRewriterRecFun.scala +++ b/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestSymbStateRewriterRecFun.scala @@ -1,34 +1,42 @@ package at.forsyte.apalache.tla.bmcmt +import at.forsyte.apalache.tla.lir.TypedPredefs._ import at.forsyte.apalache.tla.lir._ -import at.forsyte.apalache.tla.lir.convenience.tla -import at.forsyte.apalache.tla.lir.values.TlaIntSet -import at.forsyte.apalache.tla.lir.UntypedPredefs._ +import at.forsyte.apalache.tla.lir.convenience.tla._ import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner @RunWith(classOf[JUnitRunner]) class TestSymbStateRewriterRecFun extends RewriterBase with TestingPredefs { - test("""recursive fun: f[n \in { 1, 2, 3 }] == IF n <= 1 THEN 2 ELSE 2 * f[n - 1]""") { - import tla._ - - val set = enumSet(tla.int(1), tla.int(2), tla.int(3)) + private val types = + Map( + "b" -> BoolT1(), + "i" -> IntT1(), + "I" -> SetT1(IntT1()), + "i_to_i" -> FunT1(IntT1(), IntT1()) + ) - val ref = tla.withType(tla.recFunRef(), tla.funSet(ValEx(TlaIntSet), ValEx(TlaIntSet))) + test("""recursive fun: f[n \in { 1, 2, 3 }] == IF n <= 1 THEN 2 ELSE 2 * f[n - 1]""") { + val set = enumSet(int(1), int(2), int(3)) ? "I" + val ref = recFunRef() ? "i_to_i" val map = ite( - le(tla.name("n"), int(1)), + le(name("n") ? "i", int(1)) ? "b", int(2), - mult(int(2), appFun(ref, minus(tla.name("n"), int(1)))) - ) /// + mult(int(2), appFun(ref, minus(name("n") ? "i", int(1)) ? "i") ? "i") ? "i" + ).typed(types, "i") - val fun = recFunDef(map, tla.name("n"), set) + val fun = recFunDef(map, name("n") ? "i", set) + .typed(types, "i_to_i") val rewriter = create() var state = rewriter.rewriteUntilDone(new SymbState(fun, arena, Binding())) val funCell = state.ex - def resEq(i: Int, j: Int) = eql(int(j), appFun(funCell, int(i))) + def resEq(i: Int, j: Int) = { + eql(int(j), appFun(funCell ? "i_to_i", int(i)) ? "i") + .typed(types, "b") + } assertTlaExAndRestore(rewriter, state.setRex(resEq(1, 2))) assertTlaExAndRestore(rewriter, state.setRex(resEq(2, 4))) diff --git a/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestSymbStateRewriterRecord.scala b/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestSymbStateRewriterRecord.scala index 794380bee8..096af65b72 100644 --- a/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestSymbStateRewriterRecord.scala +++ b/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestSymbStateRewriterRecord.scala @@ -1,9 +1,9 @@ package at.forsyte.apalache.tla.bmcmt -import at.forsyte.apalache.tla.bmcmt.types._ -import at.forsyte.apalache.tla.lir.NameEx -import at.forsyte.apalache.tla.lir.convenience.tla -import at.forsyte.apalache.tla.lir.UntypedPredefs._ +import at.forsyte.apalache.tla.bmcmt.types.{BoolT, ConstT, FinSetT, IntT, RecordT} +import at.forsyte.apalache.tla.lir.{BoolT1, IntT1, NameEx, RecT1, SetT1, StrT1, TupT1} +import at.forsyte.apalache.tla.lir.convenience.tla._ +import at.forsyte.apalache.tla.lir.TypedPredefs._ import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner @@ -11,24 +11,43 @@ import scala.collection.immutable.{SortedMap, SortedSet, TreeMap} @RunWith(classOf[JUnitRunner]) class TestSymbStateRewriterRecord extends RewriterBase { - test("""RecordDomainCache: ~(dom {"a", "b"} = dom {"a", "b", "c"}) ~~> $C$k""") { + private val types = Map( + "b" -> BoolT1(), + "i" -> IntT1(), + "I" -> SetT1(IntT1()), + "s" -> StrT1(), + "(s)" -> TupT1(StrT1()), + "S" -> SetT1(StrT1()), + "rib" -> RecT1("a" -> IntT1(), "b" -> BoolT1()), + "RIB" -> SetT1(RecT1("a" -> IntT1(), "b" -> BoolT1())), + "ribs" -> RecT1("a" -> IntT1(), "b" -> BoolT1(), "c" -> StrT1()), + "RIBS" -> SetT1(RecT1("a" -> IntT1(), "b" -> BoolT1(), "c" -> StrT1())), + "rib" -> RecT1("a" -> IntT1(), "b" -> BoolT1()), + "rii" -> RecT1("a" -> IntT1(), "c" -> IntT1()), + "RII" -> SetT1(RecT1("a" -> IntT1(), "c" -> IntT1())) + ) + + test("""RecordDomainCache: ~(dom {"a", "b"} = dom {"a", "b", "c"})""") { val rewriter = create() val (newArena1, set1) = rewriter.recordDomainCache.create(arena, (SortedSet("a", "b"), SortedSet[String]())) val (newArena2, set2) = rewriter.recordDomainCache.create(newArena1, (SortedSet("a", "b", "c"), SortedSet[String]())) - val neq = tla.not(tla.eql(set1.toNameEx, set2.toNameEx)) + // the domains should not be equal + val neq = not(eql(set1.toNameEx, set2.toNameEx) ? "b") + .typed(types, "b") val state = new SymbState(neq, newArena2, Binding()) assertTlaExAndRestore(rewriter, state) } - test("""SE-REC-CTOR[1-2]: ["a" |-> 1, "b" |-> FALSE, "c" |-> "d"] ~~> $C$k""") { - val record = tla.enumFun(tla.str("a"), tla.int(1), tla.str("b"), tla.bool(false), tla.str("c"), tla.str("d")) + test("""["a" |-> 1, "b" |-> FALSE, "c" |-> "d"]""") { + val record = enumFun(str("a"), int(1), str("b"), bool(false), str("c"), str("d")) + .typed(types, "ribs") val state = new SymbState(record, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) nextState.ex match { - case membershipEx @ NameEx(name) => + case _ @NameEx(name) => assert(solverContext.sat()) val cell = nextState.arena.findCellByName(name) cell.cellType match { @@ -44,7 +63,8 @@ class TestSymbStateRewriterRecord extends RewriterBase { // also make sure that the domain equality works val (newArena, expectedDom) = rewriter.recordDomainCache.getOrCreate(nextState.arena, (SortedSet("a", "b", "c"), SortedSet[String]())) - val eq = tla.eql(expectedDom.toNameEx, tla.dom(cell.toNameEx)) + val eq = eql(expectedDom.toNameEx ? "S", dom(cell.toNameEx) ? "S") + .typed(types, "b") assertTlaExAndRestore(rewriter, nextState.setArena(newArena).setRex(eq)) // we check the actual contents in the later tests that access elements @@ -58,57 +78,46 @@ class TestSymbStateRewriterRecord extends RewriterBase { } } - test("""SE-REC-ACC[1-2]: ["a" |-> 1, "b" |-> FALSE, "c" |-> "d"]["c"] ~~> $C$k equals \"d\"""") { - val record = tla.enumFun(tla.str("a"), tla.int(1), tla.str("b"), tla.bool(false), tla.str("c"), tla.str("d")) + test("""["a" |-> 1, "b" |-> FALSE, "c" |-> "d"]["c"] equals "d" """) { + val record = enumFun(str("a"), int(1), str("b"), bool(false), str("c"), str("d")) + val recordAcc = appFun(record ? "ribs", str("b") ? "s") + val eqD = eql(recordAcc ? "b", bool(false)) + .typed(types, "b") - val recordAcc = tla.appFun(record, tla.str("b")) - val state = new SymbState(recordAcc, arena, Binding()) + val state = new SymbState(eqD, arena, Binding()) val rewriter = create() - val nextState = rewriter.rewriteUntilDone(state) - nextState.ex match { - case membershipEx @ NameEx(name) => - assert(solverContext.sat()) - val cell = nextState.arena.findCellByName(name) - cell.cellType match { - case BoolT() => - assertTlaExAndRestore(rewriter, nextState.setRex(tla.eql(cell.toNameEx, tla.bool(false)))) - - // we check the actual contents in the later tests that access elements - - case _ => - fail("Expected Boolean type") - } - - case _ => - fail("Unexpected rewriting result") - } + assertTlaExAndRestore(rewriter, state.setRex(eqD)) } - test("""type inference error ["a" |-> 1, "b" |-> FALSE]["c"]""") { - val record = tla.enumFun(tla.str("a"), tla.int(1), tla.str("b"), tla.bool(false)) + test("""accessing a non-existing field: ["a" |-> 1, "b" |-> FALSE]["c"]""") { + val record = enumFun(str("a"), int(1), str("b"), bool(false)) + // We assume that record has the type RecT1("a" -> IntT1(), "b" -> BoolT1(), "c" -> StrT1()). + // This can happen due to type unification. The record access should still work, + // though the access is expected to produce an arbitrary value (of proper type). + val recordAcc = appFun(record ? "ribs", str("c")) + .typed(types, "s") - val recordAcc = tla.appFun(record, tla.str("c")) val state = new SymbState(recordAcc, arena, Binding()) val rewriter = create() - assertThrows[TypeInferenceException] { - rewriter.rewriteUntilDone(state) - } + rewriter.rewriteUntilDone(state) + assert(solverContext.sat()) } - test("""SE-REC-CTOR[1-2] in a set: {["a" |-> 1, "b" |-> FALSE], ["a" |-> 2, "b" |-> TRUE]} ~~> $C$k""") { - val record1 = tla.enumFun(tla.str("a"), tla.int(1), tla.str("b"), tla.bool(false)) - val record2 = tla.enumFun(tla.str("a"), tla.int(2), tla.str("b"), tla.bool(true)) - - val state = new SymbState(tla.enumSet(record1, record2), arena, Binding()) + test("""{["a" |-> 1, "b" |-> FALSE], ["a" |-> 2, "b" |-> TRUE]}""") { + val record1 = enumFun(str("a"), int(1), str("b"), bool(false)) + val record2 = enumFun(str("a"), int(2), str("b"), bool(true)) + val set = enumSet(record1 ? "rib", record2 ? "rib") + .typed(types, "RIB") + val state = new SymbState(set, arena, Binding()) val nextState = create().rewriteUntilDone(state) + nextState.ex match { - case membershipEx @ NameEx(name) => + case NameEx(name) => assert(solverContext.sat()) val cell = nextState.arena.findCellByName(name) cell.cellType match { case FinSetT(rt @ RecordT(_)) => assert(rt.fields == TreeMap("a" -> IntT(), "b" -> BoolT())) - // we check the actual contents in the later tests that access elements case _ => @@ -120,22 +129,25 @@ class TestSymbStateRewriterRecord extends RewriterBase { } } - test("""SE-REC-CTOR[1-2] in a set: {["a" |-> 1, "b" |-> FALSE], ["a" |-> 2, "b" |-> TRUE, "c" |-> 3]} ~~> $C$k""") { - val record1 = tla.enumFun(tla.str("a"), tla.int(1), tla.str("b"), tla.bool(false)) - val record2 = tla.enumFun(tla.str("a"), tla.int(2), tla.str("b"), tla.bool(true), tla.str("c"), tla.int(3)) - // Records in a set can have different sets of keys. This requires a type annotation. - val annotation = AnnotationParser.toTla(RecordT(SortedMap("a" -> IntT(), "b" -> BoolT(), "c" -> IntT()))) - - val state = new SymbState(tla.enumSet(tla.withType(record1, annotation), record2), arena, Binding()) + test("""{["a" |-> 1, "b" |-> FALSE], ["a" |-> 2, "b" |-> TRUE, "c" |-> "foo"]}""") { + // Although record1 has two fields we provide the type `ribs`. This is how the type checker does type unification. + val record1 = enumFun(str("a"), int(1), str("b"), bool(false)) + .typed(types, "ribs") + val record2 = enumFun(str("a"), int(2), str("b"), bool(true), str("c"), str("foo")) + .typed(types, "ribs") + val recSet = enumSet(record1, record2) + .typed(types, "RIBS") + val state = new SymbState(recSet, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) + nextState.ex match { - case membershipEx @ NameEx(name) => + case NameEx(name) => assert(solverContext.sat()) val cell = nextState.arena.findCellByName(name) cell.cellType match { case FinSetT(rt @ RecordT(_)) => - assert(rt.fields == TreeMap("a" -> IntT(), "b" -> BoolT(), "c" -> IntT())) + assert(rt.fields == TreeMap("a" -> IntT(), "b" -> BoolT(), "c" -> ConstT())) case _ => fail("Unexpected type: " + cell.cellType) @@ -146,150 +158,87 @@ class TestSymbStateRewriterRecord extends RewriterBase { } } - test("""SE-REC-CTOR[1-2] type error: {["a" |-> FALSE, "b" |-> 1], ["a" |-> 2, "b" |-> TRUE]} ~~> $C$k""") { - val record1 = tla.enumFun(tla.str("a"), tla.bool(false), tla.str("b"), tla.int(1)) - val record2 = tla.enumFun(tla.str("a"), tla.int(2), tla.str("b"), tla.bool(true)) - - val state = new SymbState(tla.enumSet(record1, record2), arena, Binding()) - // this is a badly-typed expression - assertThrows[TypeInferenceException] { - create().rewriteUntilDone(state) - } - } - - test( - """filter-map a record (idiom): {r.c : r \in {r2 \in {["a" |-> 1], ["a" |-> 2, "c" |-> 3]}: r2.c = 3}} ~~> $C$k""") { + test("""filter-map a record (idiom): {r.c : r \in {r2 \in {["a" |-> 1], ["a" |-> 2, "c" |-> 3]}: r2.c = 3}}""") { // It is a common idiom in TLA+ to first filter records by the type field // and then -- when knowing the type of the filtered records -- map them somewhere. // Although, it is not easy to do in a symbolic encoding, we support this idiom. - // We require though that all the records should have type-compatible fields. - val record1 = tla.enumFun(tla.str("a"), tla.int(1)) - val record2 = tla.enumFun(tla.str("a"), tla.int(2), tla.str("c"), tla.int(3)) + // We require though that all the records have type-compatible fields. + val record1 = enumFun(str("a"), int(1)) + .typed(types, "rii") + val record2 = enumFun(str("a"), int(2), str("c"), int(3)) + .typed(types, "rii") // Records in a set can have different sets of keys. This requires a type annotation. - val annotation = AnnotationParser.toTla(RecordT(SortedMap("a" -> IntT(), "c" -> IntT()))) - val setEx = tla.enumSet(tla.withType(record1, annotation), record2) - val predEx = tla.eql(tla.appFun(tla.name("r2"), tla.str("c")), tla.int(3)) - val filteredEx = tla.filter(tla.name("r2"), setEx, predEx) - val mapEx = tla.map(tla.appFun(tla.name("r"), tla.str("c")), tla.name("r"), filteredEx) + val setEx = enumSet(record1, record2) + .typed(types, "RII") + val predEx = eql(appFun(name("r2") ? "rii", str("c")) ? "i", int(3)) + .typed(types, "b") + val filteredEx = filter(name("r2") ? "rii", setEx, predEx) + .typed(types, "RII") + val mapEx = map(appFun(name("r") ? "rii", str("c")) ? "i", name("r") ? "rii", filteredEx) + .typed(types, "I") + + val eq = eql(mapEx, enumSet(int(3)) ? "I") + .typed(types, "b") val state = new SymbState(mapEx, arena, Binding()) val rewriter = create() - rewriter.push() - val nextState = rewriter.rewriteUntilDone(state) - assertTlaExAndRestore(rewriter, nextState.setRex(tla.eql(nextState.ex, tla.enumSet(tla.int(3))))) - rewriter.pop() - rewriter.push() - val filteredState = rewriter.rewriteUntilDone(state.setRex(filteredEx)) - val cell = nextState.arena.findCellByNameEx(filteredState.ex) - assert(cell.cellType == FinSetT(RecordT(SortedMap("a" -> IntT(), "c" -> IntT())))) + assertTlaExAndRestore(rewriter, state.setRex(eq)) } - test("""SE-REC-EQ: [a |-> 1, b |-> FALSE, c |-> "d"] = [c |-> "d", b |-> FALSE, a |-> 1] ~~> TRUE""") { - val record1 = tla.enumFun(tla.str("a"), tla.int(1), tla.str("b"), tla.bool(false), tla.str("c"), tla.str("d")) - val record2 = tla.enumFun(tla.str("c"), tla.str("d"), tla.str("b"), tla.bool(false), tla.str("a"), tla.int(1)) - val eq = tla.eql(record1, record2) + test("""[a |-> 1, b |-> FALSE, c |-> "d"] = [c |-> "d", b |-> FALSE, a |-> 1]""") { + // order of the fields does not matter + val record1 = enumFun(str("a"), int(1), str("b"), bool(false), str("c"), str("d")) + val record2 = enumFun(str("c"), str("d"), str("b"), bool(false), str("a"), int(1)) + val eq = eql(record1 ? "ribs", record2 ? "ribs") + .typed(types, "b") val state = new SymbState(eq, arena, Binding()) val rewriter = create() assertTlaExAndRestore(rewriter, state) } - test("""SE-REC-EQ: ~([a |-> 1, b |-> FALSE, c |-> "d"] = [a |-> 1]) ~~> TRUE""") { - // Introduce two different records using a type annotation. The records should not be equal! - val annotation = AnnotationParser.toTla(RecordT(SortedMap("a" -> IntT(), "b" -> BoolT(), "c" -> ConstT()))) - - val record1 = tla.enumFun(tla.str("a"), tla.int(1), tla.str("b"), tla.bool(false), tla.str("c"), tla.str("d")) - val record2 = tla.enumFun(tla.str("a"), tla.int(1)) - val eq = tla.not(tla.eql(record1, tla.withType(record2, annotation))) + test("""~([a |-> 1, b |-> FALSE, c |-> "d"] = [a |-> 1]) equals TRUE""") { + val record1 = enumFun(str("a"), int(1), str("b"), bool(false), str("c"), str("d")) + val record2 = enumFun(str("a"), int(1)) + val eq = not(eql(record1 ? "ribs", record2 ? "ribs") ? "b") + .typed(types, "b") val state = new SymbState(eq, arena, Binding()) val rewriter = create() assertTlaExAndRestore(rewriter, state) } - // Keramelizer does this expansion - ignore( - """SE-REC-SET: {[n |-> 1, b |-> FALSE], [n |-> 2, b |-> FALSE], [n |-> 1, b |-> TRUE], [n |-> 2, b |-> TRUE] = {[n : {1, 2}, b : {FALSE, TRUE}}""".stripMargin) { - val set12 = tla.enumSet(1 to 2 map tla.int: _*) - val setBool = tla.enumSet(tla.bool(false), tla.bool(true)) - val prod = tla.recSet(tla.str("n"), set12, tla.str("b"), setBool) - def rec(i: Int, b: Boolean) = - tla.enumFun(tla.str("n"), tla.int(i), tla.str("b"), tla.bool(b)) - val eq = tla.eql(prod, tla.enumSet(rec(1, false), rec(1, true), rec(2, false), rec(2, true))) - - val state = new SymbState(eq, arena, Binding()) - val rewriter = create() - val nextState = rewriter.rewriteUntilDone(state) - rewriter.push() - solverContext.assertGroundExpr(nextState.ex) - assert(solverContext.sat()) - rewriter.pop() - solverContext.assertGroundExpr(tla.not(nextState.ex)) - assert(!solverContext.sat()) - } - - // Keramelizer does this expansion - ignore("""SE-REC-SET: {[n : {1, 2}} <: {[n |-> Int, b |-> BOOLEAN ]}""".stripMargin) { - val set12 = tla.enumSet(1 to 2 map tla.int: _*) - val setBool = tla.enumSet(tla.bool(false), tla.bool(true)) - val prod = tla.recSet(tla.str("n"), set12) - val expectedRecordT = FinSetT(RecordT(SortedMap("n" -> IntT(), "b" -> BoolT()))) - val annotated = tla.withType(prod, AnnotationParser.toTla(expectedRecordT)) - - val state = new SymbState(annotated, arena, Binding()) - val rewriter = create() - val nextState = rewriter.rewriteUntilDone(state) - val cell = nextState.arena.findCellByNameEx(nextState.ex) - assert(expectedRecordT == cell.cellType) - } - - // Keramelizer and ExprOptimizer rewrite assignments over records sets into existentials over records - ignore("""SE-REC-SET: x' \in {[n |-> Int, b |-> BOOLEAN ]}""".stripMargin) { - val set12 = tla.enumSet(1 to 2 map tla.int: _*) - val setBool = tla.enumSet(tla.bool(false), tla.bool(true)) - val prod = tla.recSet(tla.str("n"), set12, tla.str("b"), setBool) - val assign = tla.in(tla.prime(tla.name("x")), prod) - - val state = new SymbState(assign, arena, Binding()) - val rewriter = create() - val nextState = rewriter.rewriteUntilDone(state) - val inInt = - tla.in(tla.appFun(tla.prime(tla.name("x")), tla.str("n")), tla.enumSet(tla.int(1), tla.int(2))) - assertTlaExAndRestore(rewriter, nextState.setRex(inInt)) - - val inBool = - tla.in(tla.appFun(tla.prime(tla.name("x")), tla.str("b")), tla.enumSet(tla.bool(false), tla.bool(true))) - assertTlaExAndRestore(rewriter, nextState.setRex(inBool)) - } - - test("""SE-REC-DOM: DOMAIN [a |-> 1, b |-> FALSE, c |-> "d"] = {"a", "b", "c"}""") { - val record = tla.enumFun(tla.str("a"), tla.int(1), tla.str("b"), tla.bool(false), tla.str("c"), tla.str("d")) - val dom = tla.dom(record) - val eq = tla.eql(dom, tla.enumSet(tla.str("a"), tla.str("b"), tla.str("c"))) + test("""DOMAIN [a |-> 1, b |-> FALSE, c |-> "d"] equals {"a", "b", "c"}""") { + // the domain of a record stays the same, even if it is lifted to a more general record type + val record = enumFun(str("a"), int(1), str("b"), bool(false), str("c"), str("d")) + val domain = dom(record ? "ribs") + val eq = eql(domain ? "S", enumSet(str("a"), str("b"), str("c")) ? "S") + .typed(types, "b") val state = new SymbState(eq, arena, Binding()) val rewriter = create() assertTlaExAndRestore(rewriter, state) } - test("""SE-REC-DOM: DOMAIN ([a |-> 1] <: [a |-> 1, b |-> FALSE, c |-> "d"]) = {"a", "b", "c"}""") { - val record = tla.enumFun(tla.str("a"), tla.int(1), tla.str("b"), tla.bool(false), tla.str("c"), tla.str("d")) - val richerType = AnnotationParser.toTla(RecordT(SortedMap("a" -> IntT(), "b" -> BoolT(), "c" -> ConstT()))) - val annotated = - tla.withType(tla.enumFun(tla.str("a"), tla.int(1)), richerType) - val dom = tla.dom(annotated) - val eq = tla.eql(dom, tla.enumSet(tla.str("a"))) + test("""DOMAIN [a |-> 1] = {"a"} under type annotations!""") { + val record = enumFun(str("a"), int(1)) + .typed(types, "ribs") + val domain = dom(record) + val eq = eql(domain ? "S", enumSet(str("a")) ? "S") + .typed(types, "b") val state = new SymbState(eq, arena, Binding()) val rewriter = create() assertTlaExAndRestore(rewriter, state) } - test("""SE-REC-EXCEPT:[ ["a" |-> 1, "b" |-> FALSE] EXCEPT !["a"] = 3 ]""") { - val record = tla.enumFun(tla.str("a"), tla.int(1), tla.str("b"), tla.bool(false)) - val recExcept = tla.except(record, tla.tuple(tla.str("a")), tla.int(3)) + test("""[ ["a" |-> 1, "b" |-> FALSE] EXCEPT !["a"] = 3 ]""") { + val record = enumFun(str("a"), int(1), str("b"), bool(false)) + val updatedRec = except(record ? "rib", tuple(str("a")) ? "(s)", int(3)) + .typed(types, "rib") + val expectedRec = enumFun(str("a"), int(3), str("b"), bool(false)) + .typed(types, "rib") + val eq = eql(expectedRec, updatedRec) + .typed(types, "b") - val state = new SymbState(recExcept, arena, Binding()) + val state = new SymbState(eq, arena, Binding()) val rewriter = create() - val nextState = rewriter.rewriteUntilDone(state) - val expectedRec = tla.enumFun(tla.str("a"), tla.int(3), tla.str("b"), tla.bool(false)) - assertTlaExAndRestore(rewriter, nextState.setRex(tla.eql(expectedRec, nextState.ex))) + assertTlaExAndRestore(rewriter, state) } - } diff --git a/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestSymbStateRewriterSequence.scala b/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestSymbStateRewriterSequence.scala index f934a60739..3d039cf894 100644 --- a/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestSymbStateRewriterSequence.scala +++ b/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestSymbStateRewriterSequence.scala @@ -1,22 +1,28 @@ package at.forsyte.apalache.tla.bmcmt import at.forsyte.apalache.tla.bmcmt.types._ -import at.forsyte.apalache.tla.lir.NameEx -import at.forsyte.apalache.tla.lir.convenience.tla -import at.forsyte.apalache.tla.lir.oper.TlaFunOper -import at.forsyte.apalache.tla.lir.UntypedPredefs._ +import at.forsyte.apalache.tla.lir.TypedPredefs._ +import at.forsyte.apalache.tla.lir.convenience.tla._ +import at.forsyte.apalache.tla.lir._ import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner @RunWith(classOf[JUnitRunner]) class TestSymbStateRewriterSequence extends RewriterBase { + private val types = Map( + "i" -> IntT1(), + "I" -> SetT1(IntT1()), + "b" -> BoolT1(), + "Qi" -> SeqT1(IntT1()) + ) + // As sequences are not distinguishable from tuples, we need a type annotation. // In the not so far away future, a type inference engine would tell us, whether to construct a sequence or a tuple - test("""SE-SEQ-CTOR: <<>> <: Seq(Int)""") { - val tuple = tla.tuple() - val annotatedTuple = tla.withType(tuple, AnnotationParser.toTla(SeqT(IntT()))) + test("""<<>> as Seq(Int)""") { + val tup = tuple() + .typed(types, "Qi") - val state = new SymbState(annotatedTuple, arena, Binding()) + val state = new SymbState(tup, arena, Binding()) val nextState = create().rewriteUntilDone(state) assert(solverContext.sat()) nextState.ex match { @@ -30,11 +36,11 @@ class TestSymbStateRewriterSequence extends RewriterBase { } } - test("""SE-SEQ-CTOR: <<1, 2, 3>> <: Seq(Int)""") { - val tuple = tla.tuple(1.to(3).map(i => tla.int(i)): _*) - val annotatedTuple = tla.withType(tuple, AnnotationParser.toTla(SeqT(IntT()))) + test("""<<1, 2, 3>> as Seq(Int)""") { + val tup = tuple(1.to(3).map(int): _*) + .typed(types, "Qi") - val state = new SymbState(annotatedTuple, arena, Binding()) + val state = new SymbState(tup, arena, Binding()) val nextState = create().rewriteUntilDone(state) assert(solverContext.sat()) nextState.ex match { @@ -48,92 +54,86 @@ class TestSymbStateRewriterSequence extends RewriterBase { } } - test("""SE-SEQ-APP: (<<3, 4, 5>> <: Seq(Int))[2]""") { - val tuple = tla.tuple(3.to(5).map(i => tla.int(i)): _*) - val annotatedTuple = tla.withType(tuple, AnnotationParser.toTla(SeqT(IntT()))) + test("""(<<3, 4, 5>> as Seq(Int))[2]""") { + val tup = tuple(3.to(5).map(int): _*) + .typed(types, "Qi") - val state = new SymbState(annotatedTuple, arena, Binding()) + val state = new SymbState(tup, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) assert(solverContext.sat()) val seq = nextState.asCell - assertTlaExAndRestore(rewriter, nextState.setRex(tla.eql(tla.appFun(seq.toNameEx, tla.int(1)), tla.int(3)))) - assertTlaExAndRestore(rewriter, nextState.setRex(tla.eql(tla.appFun(seq.toNameEx, tla.int(2)), tla.int(4)))) - assertTlaExAndRestore(rewriter, nextState.setRex(tla.eql(tla.appFun(seq.toNameEx, tla.int(3)), tla.int(5)))) + val eq1 = eql(appFun(seq.toNameEx ? "Qi", int(1)) ? "i", int(3)) + .typed(types, "b") + assertTlaExAndRestore(rewriter, nextState.setRex(eq1)) + val eq2 = eql(appFun(seq.toNameEx ? "Qi", int(2)) ? "i", int(4)) + .typed(types, "b") + assertTlaExAndRestore(rewriter, nextState.setRex(eq2)) + val eq3 = eql(appFun(seq.toNameEx ? "Qi", int(3)) ? "i", int(5)) + .typed(types, "b") + assertTlaExAndRestore(rewriter, nextState.setRex(eq3)) } - test("""SE-SEQ-APP: (<<>> <: Seq(Int))[1]""") { + test("""(<<>> as Seq(Int))[1]""") { // regression: <<>>[1] should produce no contradiction, nor throw an exception - val tuple = tla.tuple() - val annotatedTuple = tla.withType(tuple, AnnotationParser.toTla(SeqT(IntT()))) + val tup = tuple() + .typed(types, "Qi") - val state = new SymbState(annotatedTuple, arena, Binding()) + val state = new SymbState(tup, arena, Binding()) val rewriter = create() - val nextState = rewriter.rewriteUntilDone(state) + val _ = rewriter.rewriteUntilDone(state) assert(solverContext.sat()) } - test("""SE-SEQ-HEAD: Head(<<3, 4, 5>> <: Seq(Int))""") { - val tuple = tla.tuple(3.to(5).map(i => tla.int(i)): _*) - val annotatedTuple = tla.withType(tuple, AnnotationParser.toTla(SeqT(IntT()))) - val seqApp = tla.head(annotatedTuple) + test("""Head(<<3, 4, 5>> as Seq(Int))""") { + val tup = tuple(3.to(5).map(int): _*) ? "Qi" + val seqHead = head(tup) ? "i" + val eq = eql(seqHead, int(3)) + .typed(types, "b") - val state = new SymbState(seqApp, arena, Binding()) - val rewriter = create() - val nextState = rewriter.rewriteUntilDone(state) - assert(solverContext.sat()) - val result = nextState.asCell - assert(IntT() == result.cellType) - assertTlaExAndRestore(rewriter, nextState.setRex(tla.eql(result.toNameEx, tla.int(3)))) + val state = new SymbState(eq, arena, Binding()) + assertTlaExAndRestore(create(), state) } - test("""SE-SEQ-LEN: Len(<<3, 4, 5>> <: Seq(Int))""") { - val tuple = tla.tuple(3.to(5).map(i => tla.int(i)): _*) - val annotatedTuple = tla.withType(tuple, AnnotationParser.toTla(SeqT(IntT()))) - val seqApp = tla.len(annotatedTuple) + test("""Len(<<3, 4, 5>> <: Seq(Int))""") { + val tup = tuple(3.to(5).map(i => int(i)): _*) + .typed(types, "Qi") + val seqLen = len(tup) ? "i" + val eq = eql(seqLen, int(3)) + .typed(types, "b") - val state = new SymbState(seqApp, arena, Binding()) - val rewriter = create() - val nextState = rewriter.rewriteUntilDone(state) - assert(solverContext.sat()) - val result = nextState.asCell - assert(IntT() == result.cellType) - assertTlaExAndRestore(rewriter, nextState.setRex(tla.eql(result.toNameEx, tla.int(3)))) + val state = new SymbState(eq, arena, Binding()) + assertTlaExAndRestore(create(), state) } - test("""SE-SEQ-TAIL: Tail(<<3, 4, 5>> <: Seq(Int))""") { - val tuple = tla.tuple(3.to(5).map(i => tla.int(i)): _*) - val annotatedTuple = tla.withType(tuple, AnnotationParser.toTla(SeqT(IntT()))) - val seqTail = tla.tail(annotatedTuple) + test("""Tail(<<3, 4, 5>> as Seq(Int))""") { + val tup = tuple(3.to(5).map(int): _*) ? "Qi" + val seqTail = tail(tup) ? "Qi" + val expected = tuple(int(4), int(5)) ? "Qi" + val eq = eql(seqTail, expected) + .typed(types, "b") - val state = new SymbState(seqTail, arena, Binding()) - val rewriter = create() - val nextState = rewriter.rewriteUntilDone(state) - assert(solverContext.sat()) - val result = nextState.asCell - assert(SeqT(IntT()) == result.cellType) - assertTlaExAndRestore(rewriter, nextState.setRex(tla.eql(tla.appFun(result.toNameEx, tla.int(1)), tla.int(4)))) - assertTlaExAndRestore(rewriter, nextState.setRex(tla.eql(tla.appFun(result.toNameEx, tla.int(2)), tla.int(5)))) + val state = new SymbState(eq, arena, Binding()) + assertTlaExAndRestore(create(), state) } - test("""regression: Tail(<<>> <: Seq(Int)) does not unsat and its length is zero""") { - val emptyTuple = tla.tuple() - val annotatedTuple = tla.withType(emptyTuple, AnnotationParser.toTla(SeqT(IntT()))) - val seqTail = tla.tail(annotatedTuple) + test("""regression: Tail(<<>> as Seq(Int)) does not unsat and its length is zero""") { + val emptyTuple = tuple() + .typed(types, "Qi") + val seqTail = tail(emptyTuple) + .typed(types, "Qi") + val eq = eql(len(seqTail) ? "i", int(0)) + .typed(types, "i") - val state = new SymbState(seqTail, arena, Binding()) - val rewriter = create() - val nextState = rewriter.rewriteUntilDone(state) - // in this case, Tail may return an arbitrary value, but it should not get stuck! - assert(solverContext.sat()) - // the length of the new sequence is 0, not -1 - assertTlaExAndRestore(rewriter, nextState.setRex(tla.eql(tla.int(0), tla.len(nextState.ex)))) + val state = new SymbState(eq, arena, Binding()) + assertTlaExAndRestore(create(), state) } - test("""SE-SEQ-SUBSEQ: SubSeq(S, 2, 4)""") { - val tuple = tla.tuple(3.to(6).map(i => tla.int(i)): _*) - val annotatedTuple = tla.withType(tuple, AnnotationParser.toTla(SeqT(IntT()))) - val subseqEx = tla.subseq(annotatedTuple, tla.int(2), tla.int(3)) + test("""SubSeq(S, 2, 4)""") { + val tup = tuple(3.to(6).map(int): _*) + .typed(types, "Qi") + val subseqEx = subseq(tup, int(2), int(3)) + .typed(types, "Qi") val state = new SymbState(subseqEx, arena, Binding()) val rewriter = create() @@ -141,122 +141,121 @@ class TestSymbStateRewriterSequence extends RewriterBase { assert(solverContext.sat()) val result = nextState.asCell assert(SeqT(IntT()) == result.cellType) - assertTlaExAndRestore(rewriter, nextState.setRex(tla.eql(tla.appFun(result.toNameEx, tla.int(1)), tla.int(4)))) - assertTlaExAndRestore(rewriter, nextState.setRex(tla.eql(tla.appFun(result.toNameEx, tla.int(2)), tla.int(5)))) - assertTlaExAndRestore(rewriter, nextState.setRex(tla.eql(tla.len(result.toNameEx), tla.int(2)))) + val eq1 = eql(appFun(result.toNameEx ? "Qi", int(1)) ? "i", int(4)) + .typed(types, "b") + assertTlaExAndRestore(rewriter, nextState.setRex(eq1)) + val eq2 = eql(appFun(result.toNameEx ? "Qi", int(2)) ? "i", int(5)) + .typed(types, "b") + assertTlaExAndRestore(rewriter, nextState.setRex(eq2)) + val eq3 = eql(len(result.toNameEx ? "Qi") ? "i", int(2)) + .typed(types, "b") + assertTlaExAndRestore(rewriter, nextState.setRex(eq3)) } - test("""regression: SE-SEQ-SUBSEQ: SubSeq(S, 3, 1) does not unsat and has length 0""") { - val tuple = tla.tuple(3.to(6).map(i => tla.int(i)): _*) - val annotatedTuple = tla.withType(tuple, AnnotationParser.toTla(SeqT(IntT()))) - val subseqEx = tla.subseq(annotatedTuple, tla.int(3), tla.int(1)) + test("""regression: SubSeq(S, 3, 1) does not unsat and has length 0""") { + val tup = tuple(3.to(6).map(int): _*) + .typed(types, "Qi") + val subseqEx = subseq(tup, int(3), int(1)) + .typed(types, "Qi") + val eq = eql(len(subseqEx) ? "i", int(0)) + .typed(types, "b") - val state = new SymbState(subseqEx, arena, Binding()) - val rewriter = create() - val nextState = rewriter.rewriteUntilDone(state) - // in this case, the solver should not be stuck by unsat, the value is simply arbitrary - assert(solverContext.sat()) - // the length of the new sequence is 0, not -1 - assertTlaExAndRestore(rewriter, nextState.setRex(tla.eql(tla.int(0), tla.len(nextState.ex)))) + val state = new SymbState(eq, arena, Binding()) + assertTlaExAndRestore(create(), state) } - test("""SE-SEQ-APPEND: Append(S, 10)""") { - val tuple = tla.tuple(4.to(5).map(i => tla.int(i)): _*) - val annotatedTuple = tla.withType(tuple, AnnotationParser.toTla(SeqT(IntT()))) - val append = tla.append(annotatedTuple, tla.int(10)) + test("""Append(S, 10)""") { + val tup = tuple(4.to(5).map(int): _*) + .typed(types, "Qi") + val seqAppend = append(tup, int(10)) + .typed(types, "Qi") - val state = new SymbState(append, arena, Binding()) + val state = new SymbState(seqAppend, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) assert(solverContext.sat()) val result = nextState.asCell assert(SeqT(IntT()) == result.cellType) - assertTlaExAndRestore(rewriter, nextState.setRex(tla.eql(tla.appFun(result.toNameEx, tla.int(1)), tla.int(4)))) - assertTlaExAndRestore(rewriter, nextState.setRex(tla.eql(tla.appFun(result.toNameEx, tla.int(2)), tla.int(5)))) - assertTlaExAndRestore(rewriter, nextState.setRex(tla.eql(tla.appFun(result.toNameEx, tla.int(3)), tla.int(10)))) - assertTlaExAndRestore(rewriter, nextState.setRex(tla.eql(tla.len(result.toNameEx), tla.int(3)))) + val eq1 = eql(appFun(result.toNameEx ? "Qi", int(1)) ? "i", int(4)) + .typed(types, "b") + assertTlaExAndRestore(rewriter, nextState.setRex(eq1)) + val eq2 = eql(appFun(result.toNameEx ? "Qi", int(2)) ? "i", int(5)) + .typed(types, "b") + assertTlaExAndRestore(rewriter, nextState.setRex(eq2)) + val eq3 = eql(appFun(result.toNameEx ? "Qi", int(3)) ? "i", int(10)) + .typed(types, "b") + assertTlaExAndRestore(rewriter, nextState.setRex(eq3)) + val eq4 = eql(len(result.toNameEx ? "Qi") ? "i", int(3)).typed(types, "b") + assertTlaExAndRestore(rewriter, nextState.setRex(eq4)) } - test("""SE-SEQ-APPEND: Append(SubSeq(S, 2, 3), 10)""") { - val tuple = tla.tuple(3.to(6).map(i => tla.int(i)): _*) - val annotatedTuple = tla.withType(tuple, AnnotationParser.toTla(SeqT(IntT()))) - val subseqEx = tla.subseq(annotatedTuple, tla.int(2), tla.int(3)) - val append = tla.append(subseqEx, tla.int(10)) + test("""Append(SubSeq(S, 2, 3), 10)""") { + val tup = tuple(3.to(6).map(int): _*) ? "Qi" + val subseqEx = subseq(tup, int(2), int(3)) ? "Qi" + val seqAppend = append(subseqEx, int(10)) + .typed(types, "Qi") - val state = new SymbState(append, arena, Binding()) + val state = new SymbState(seqAppend, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) assert(solverContext.sat()) val result = nextState.asCell assert(SeqT(IntT()) == result.cellType) - assertTlaExAndRestore(rewriter, nextState.setRex(tla.eql(tla.appFun(result.toNameEx, tla.int(1)), tla.int(4)))) - assertTlaExAndRestore(rewriter, nextState.setRex(tla.eql(tla.appFun(result.toNameEx, tla.int(2)), tla.int(5)))) - assertTlaExAndRestore(rewriter, nextState.setRex(tla.eql(tla.appFun(result.toNameEx, tla.int(3)), tla.int(10)))) - assertTlaExAndRestore(rewriter, nextState.setRex(tla.eql(tla.len(result.toNameEx), tla.int(3)))) + val eq1 = eql(appFun(result.toNameEx ? "Qi", int(1)) ? "i", int(4)) + .typed(types, "b") + assertTlaExAndRestore(rewriter, nextState.setRex(eq1)) + val eq2 = eql(appFun(result.toNameEx ? "Qi", int(2)) ? "i", int(5)) + .typed(types, "b") + assertTlaExAndRestore(rewriter, nextState.setRex(eq2)) + val eq3 = eql(appFun(result.toNameEx ? "Qi", int(3)) ? "i", int(10)) + .typed(types, "b") + assertTlaExAndRestore(rewriter, nextState.setRex(eq3)) + val eq4 = eql(len(result.toNameEx ? "Qi") ? "i", int(3)).typed(types, "b") + assertTlaExAndRestore(rewriter, nextState.setRex(eq4)) } - test("""SE-SEQ-EQ: <<4, 5>> = SubSeq(<<3, 4, 5, 6>>, 2, 3)""") { - val tuple3456 = tla.tuple(3.to(6).map(i => tla.int(i)): _*) - val annot3456 = tla.withType(tuple3456, AnnotationParser.toTla(SeqT(IntT()))) - val subseqEx = tla.subseq(annot3456, tla.int(2), tla.int(3)) - val tuple45 = tla.tuple(4.to(5).map(i => tla.int(i)): _*) - val annot45 = tla.withType(tuple45, AnnotationParser.toTla(SeqT(IntT()))) - val eq = tla.eql(annot45, subseqEx) + test("""<<4, 5>> = SubSeq(<<3, 4, 5, 6>>, 2, 3)""") { + val tup3456 = tuple(3.to(6).map(int): _*) ? "Qi" + val subseqEx = subseq(tup3456, int(2), int(3)) ? "Qi" + val tup45 = tuple(4.to(5).map(int): _*) ? "Qi" + val eq = eql(tup45, subseqEx) + .typed(types, "b") val state = new SymbState(eq, arena, Binding()) - val rewriter = create() - val nextState = rewriter.rewriteUntilDone(state) - assert(solverContext.sat()) - val result = nextState.asCell - assert(BoolT() == result.cellType) - assertTlaExAndRestore(rewriter, nextState.setRex(tla.eql(tla.bool(true), nextState.ex))) + assertTlaExAndRestore(create(), state) } - test("""SE-SEQ-DOMAIN: DOMAIN SubSeq(<<3, 4, 5, 6>>, 2, 3) = {2, 3}""") { - val tuple3456 = tla.tuple(3.to(6).map(i => tla.int(i)): _*) - val annot3456 = tla.withType(tuple3456, AnnotationParser.toTla(SeqT(IntT()))) - val subseqEx = tla.subseq(annot3456, tla.int(2), tla.int(3)) - val domEx = tla.dom(subseqEx) + test("""DOMAIN SubSeq(<<3, 4, 5, 6>>, 2, 3) equals {2, 3}""") { + val tup3456 = tuple(3.to(6).map(int): _*) ? "Qi" + val subseqEx = subseq(tup3456, int(2), int(3)) ? "Qi" + val domEx = dom(subseqEx) ? "I" + val eq = eql(domEx, enumSet(int(2), int(3)) ? "I") + .typed(types, "b") - val state = new SymbState(domEx, arena, Binding()) - val rewriter = create() - val nextState = rewriter.rewriteUntilDone(state) - assert(solverContext.sat()) - assertTlaExAndRestore(rewriter, nextState.setRex(tla.eql(tla.enumSet(tla.int(2), tla.int(3)), nextState.ex))) + val state = new SymbState(eq, arena, Binding()) + assertTlaExAndRestore(create(), state) } - test("""SEQ-CONCAT: <<9, 10>> \o SubSeq(S, 2, 3)""") { - val tuple3_6 = tla.tuple(3.to(6).map(i => tla.int(i)): _*).untyped() - val seqT = AnnotationParser.toTla(SeqT(IntT())) - val annotatedTuple = tla.withType(tuple3_6, seqT).untyped() - val subseq = tla.subseq(annotatedTuple, tla.int(2), tla.int(3)).untyped() // <<4, 5>> - val tuple9_10 = tla.tuple(9.to(10).map(i => tla.int(i)): _*).untyped() - val annotatedTuple9_10 = tla.withType(tuple9_10, seqT).untyped() - val concat = tla.concat(annotatedTuple9_10, subseq).untyped() + test("""<<9, 10>> \o SubSeq(S, 2, 3)""") { + val tup3_6 = tuple(3.to(6).map(int): _*) ? "Qi" + val subseqRes = subseq(tup3_6, int(2), int(3)) ? "Qi" // <<4, 5>> + val tup9_10 = tuple(int(9), int(10)) ? "Qi" + val concatRes = concat(tup9_10, subseqRes) ? "Qi" + val expected = tuple(int(9), int(10), int(4), int(5)) ? "Qi" + val eq = eql(concatRes, expected) + .typed(types, "b") - val state = new SymbState(concat, arena, Binding()) - val rewriter = create() - val nextState = rewriter.rewriteUntilDone(state) - assert(solverContext.sat()) - val result = nextState.asCell - assert(SeqT(IntT()) == result.cellType) - - val tupleExpected = tla.tuple(tla.int(9), tla.int(10), tla.int(4), tla.int(5)) - - val expected = tla.withType(tupleExpected, seqT).untyped() - - assertTlaExAndRestore(rewriter, nextState.setRex(tla.eql(expected, nextState.ex))) + val state = new SymbState(eq, arena, Binding()) + assertTlaExAndRestore(create(), state) } - test("""regression: SEQ-CONCAT: <<9, 10>> \o Tail(<<>>) does not unsat""") { - val seqT = AnnotationParser.toTla(SeqT(IntT())) - val empty = tla.withType(tla.tuple(), seqT) - val t9_10 = tla.tuple(9.to(10).map(i => tla.int(i)): _*) - val tuple9_10 = tla.withType(t9_10, seqT) + test("""regression: <<9, 10>> \o Tail(<<>>) does not unsat""") { + val t9_10 = tuple(int(9), int(10)) ? "Qi" // Tail(<<>>) produces some undefined value. In this case, \o should also produce an undefined value. - val concat = tla.concat(tuple9_10, tla.tail(empty)) + val concatRes = concat(t9_10, tail(tuple() ? "Qi") ? "Qi") + .typed(types, "Qi") - val state = new SymbState(concat, arena, Binding()) + val state = new SymbState(concatRes, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) // the result is undefined, but it should be sat @@ -264,6 +263,4 @@ class TestSymbStateRewriterSequence extends RewriterBase { } // for PICK see TestCherryPick - - // TODO: except } diff --git a/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestSymbStateRewriterSet.scala b/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestSymbStateRewriterSet.scala index 3902a02692..d93aed398e 100644 --- a/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestSymbStateRewriterSet.scala +++ b/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestSymbStateRewriterSet.scala @@ -1,32 +1,47 @@ package at.forsyte.apalache.tla.bmcmt +import at.forsyte.apalache.tla.bmcmt.smt.{PreproSolverContext, SolverConfig, Z3SolverContext} import at.forsyte.apalache.tla.bmcmt.types._ -import at.forsyte.apalache.tla.bmcmt.types.eager.TrivialTypeFinder +import at.forsyte.apalache.tla.lir.TypedPredefs._ import at.forsyte.apalache.tla.lir._ -import at.forsyte.apalache.tla.lir.convenience.tla -import at.forsyte.apalache.tla.lir.oper._ -import at.forsyte.apalache.tla.lir.values.{TlaBool, TlaInt, TlaIntSet, TlaNatSet} -import at.forsyte.apalache.tla.lir.UntypedPredefs._ +import at.forsyte.apalache.tla.lir.convenience.tla._ +import at.forsyte.apalache.tla.pp.TlaInputError import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner @RunWith(classOf[JUnitRunner]) -class TestSymbStateRewriterSet extends RewriterBase with TestingPredefs { - private def emptySetWithType(elemT: CellT): TlaEx = - tla.withType(tla.enumSet(), AnnotationParser.toTla(FinSetT(elemT))) - - test("""SE-SET-CTOR[1-2]: {x, y, z} ~~> c_set""") { - val ex = OperEx(TlaSetOper.enumSet, NameEx("x"), NameEx("y"), NameEx("z")) +class TestSymbStateRewriterSet extends RewriterBase { + private val types = Map( + "i" -> IntT1(), + "I" -> SetT1(IntT1()), + "II" -> SetT1(SetT1(IntT1())), + "III" -> SetT1(SetT1(SetT1(IntT1()))), + "IV" -> SetT1(SetT1(SetT1(SetT1(IntT1())))), + "V" -> SetT1(SetT1(SetT1(SetT1(SetT1(IntT1()))))), + "b" -> BoolT1(), + "B" -> SetT1(BoolT1()), + "i_to_b" -> FunT1(IntT1(), BoolT1()), + "ib" -> TupT1(IntT1(), BoolT1()), + "IB" -> SetT1(TupT1(IntT1(), BoolT1())) + ) + + test("""{ x, y, z } ~~> c_set""") { + val ex = enumSet(name("x") ? "i", name("y") ? "b", name("z") ? "b") + .typed(types, "B") val binding = Binding("x" -> arena.cellFalse(), "y" -> arena.cellTrue(), "z" -> arena.cellFalse()) val state = new SymbState(ex, arena, binding) create().rewriteOnce(state) match { case SymbStateRewriter.Continue(nextState) => nextState.ex match { - case set @ NameEx(name) => - solverContext.assertGroundExpr(OperEx(TlaSetOper.in, arena.cellFalse().toNameEx, set)) + case set @ NameEx(_) => + val falseInSet = in(arena.cellFalse().toNameEx ? "b", set ? "B") + .typed(types, "b") + solverContext.assertGroundExpr(falseInSet) assert(solverContext.sat()) + val notTrueInSet = not(in(arena.cellTrue().toNameEx ? "b", set ? "B") ? "b") + .typed(types, "b") solverContext - .assertGroundExpr(OperEx(TlaBoolOper.not, OperEx(TlaSetOper.in, arena.cellTrue().toNameEx, set))) + .assertGroundExpr(notTrueInSet) assert(!solverContext.sat()) case _ => @@ -38,13 +53,14 @@ class TestSymbStateRewriterSet extends RewriterBase with TestingPredefs { } } - test("""SE-SET-CTOR[1-2]: {1, 3, 5} ~~> c_set""") { - val ex = OperEx(TlaSetOper.enumSet, ValEx(TlaInt(1)), ValEx(TlaInt(3)), ValEx(TlaInt(5))) + test("""{1, 3, 5} ~~> c_set""") { + val ex = enumSet(int(1), int(3), int(5)) + .typed(types, "I") val state = new SymbState(ex, arena, Binding()) create().rewriteOnce(state) match { case SymbStateRewriter.Continue(nextState) => nextState.ex match { - case set @ NameEx(name) => + case NameEx(_) => assert(solverContext.sat()) case _ => fail("Unexpected rewriting result") @@ -55,19 +71,17 @@ class TestSymbStateRewriterSet extends RewriterBase with TestingPredefs { } } - test("""SE-SET-IN1: {} \in {} ~~> $B$0""") { - def mkSet(elems: TlaEx*) = OperEx(TlaSetOper.enumSet, elems: _*) - - val ex = OperEx(TlaSetOper.in, emptySetWithType(IntT()), emptySetWithType(FinSetT(IntT()))) + test("""{} \in {}""") { + val ex = in(enumSet() ? "I", enumSet() ? "II") + .typed(types, "b") val state = new SymbState(ex, arena, Binding()) val nextState = create().rewriteUntilDone(state) assert(nextState.arena.cellFalse().toNameEx == nextState.ex) } - test("""SE-SET-IN1: 3 \in {1, 3, 5} ~~> $B$k""") { - def mkSet(elems: TlaEx*) = OperEx(TlaSetOper.enumSet, elems: _*) - - val ex = OperEx(TlaSetOper.in, ValEx(TlaInt(3)), mkSet(ValEx(TlaInt(1)), ValEx(TlaInt(3)), ValEx(TlaInt(5)))) + test("""3 \in {1, 3, 5}""") { + val ex = in(int(3), enumSet(int(1), int(3), int(5)) ? "I") + .typed(types, "b") val state = new SymbState(ex, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) @@ -77,7 +91,7 @@ class TestSymbStateRewriterSet extends RewriterBase with TestingPredefs { solverContext.assertGroundExpr(predEx) assert(solverContext.sat()) rewriter.pop() - solverContext.assertGroundExpr(OperEx(TlaBoolOper.not, predEx)) + solverContext.assertGroundExpr(not(predEx ? "b").typed(types, "b")) assert(!solverContext.sat()) case _ => @@ -85,9 +99,10 @@ class TestSymbStateRewriterSet extends RewriterBase with TestingPredefs { } } - test("""SE-SET-IN1: {3} \in {{1}, {3}, {5}} ~~> $B$k""") { - val ex = tla.in(tla.enumSet(tla.int(3)), - tla.enumSet(tla.enumSet(tla.int(1)), tla.enumSet(tla.int(3)), tla.enumSet(tla.int(5)))) + test("""{3} \in {{1}, {3}, {5}}""") { + val ex = in(enumSet(int(3) ? "i") ? "I", + enumSet(enumSet(int(1)) ? "I", enumSet(int(3)) ? "I", enumSet(int(5)) ? "I") ? "II") + .typed(types, "b") val state = new SymbState(ex, arena, Binding()) val rewriter = create() @@ -98,7 +113,7 @@ class TestSymbStateRewriterSet extends RewriterBase with TestingPredefs { solverContext.assertGroundExpr(predEx) assert(solverContext.sat()) rewriter.pop() - solverContext.assertGroundExpr(OperEx(TlaBoolOper.not, predEx)) + solverContext.assertGroundExpr(not(predEx ? "b").typed(types, "b")) assert(!solverContext.sat()) case _ => @@ -106,10 +121,9 @@ class TestSymbStateRewriterSet extends RewriterBase with TestingPredefs { } } - test("""SE-SET-IN1: 2 \in {1, 3, 5} ~~> $B$k""") { - def mkSet(elems: TlaEx*) = OperEx(TlaSetOper.enumSet, elems: _*) - - val ex = OperEx(TlaSetOper.in, ValEx(TlaInt(2)), mkSet(ValEx(TlaInt(1)), ValEx(TlaInt(3)), ValEx(TlaInt(5)))) + test("""2 \in {1, 3, 5}""") { + val ex = in(int(2), enumSet(int(1), int(3), int(5)) ? "I") + .typed(types, "b") val state = new SymbState(ex, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) @@ -119,7 +133,7 @@ class TestSymbStateRewriterSet extends RewriterBase with TestingPredefs { solverContext.assertGroundExpr(predEx) assert(!solverContext.sat()) rewriter.pop() - solverContext.assertGroundExpr(OperEx(TlaBoolOper.not, predEx)) + solverContext.assertGroundExpr(not(predEx ? "b").typed(types, "b")) assert(solverContext.sat()) case _ => @@ -127,8 +141,9 @@ class TestSymbStateRewriterSet extends RewriterBase with TestingPredefs { } } - test("""SE-SET-IN-INT: 2 \in Int""") { - val ex = OperEx(TlaSetOper.in, tla.int(2), ValEx(TlaIntSet)) + test("""2 \in Int""") { + val ex = in(int(2), intSet() ? "I") + .typed(types, "b") val state = new SymbState(ex, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) @@ -136,8 +151,9 @@ class TestSymbStateRewriterSet extends RewriterBase with TestingPredefs { assertTlaExAndRestore(rewriter, nextState) } - test("""SE-SET-IN-Nat: 2 \in Nat""") { - val ex = OperEx(TlaSetOper.in, tla.int(2), ValEx(TlaNatSet)) + test("""2 \in Nat""") { + val ex = in(int(2), natSet() ? "I") + .typed(types, "b") val state = new SymbState(ex, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) @@ -145,28 +161,19 @@ class TestSymbStateRewriterSet extends RewriterBase with TestingPredefs { assertTlaExAndRestore(rewriter, nextState) } - test("""SE-SET-IN-Nat: -1 \in Nat""") { - val ex = OperEx(TlaSetOper.in, tla.int(-1), ValEx(TlaNatSet)) + test("""-1 \in Nat""") { + val ex = in(int(-1), natSet() ? "I") + .typed(types, "b") val state = new SymbState(ex, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) assert(solverContext.sat()) - assertTlaExAndRestore(rewriter, nextState.setRex(tla.not(nextState.ex))) - } - - test("""type inference 3 \in {{1}, {3}, {5}}""") { - // this test worked in the previous versions, but now it just reports a type inference error - val ex = tla.in(tla.int(3), tla.enumSet(tla.enumSet(tla.int(1)), tla.enumSet(tla.int(3)), tla.enumSet(tla.int(5)))) - - val state = new SymbState(ex, arena, Binding()) - val rewriter = create() - assertThrows[TypeInferenceException] { - rewriter.rewriteUntilDone(state) - } + assertTlaExAndRestore(rewriter, nextState.setRex(not(nextState.ex ? "b").typed(types, "b"))) } - test("""SE-SET-NOTIN1: ~({} \in {}) ~~> $B$1""") { - val ex = tla.not(tla.in(emptySetWithType(FinSetT(IntT())), emptySetWithType(FinSetT(FinSetT(IntT()))))) + test("""~({} \in {})""") { + val ex = not(in(enumSet() ? "I", enumSet() ? "II") ? "b") + .typed(types, "b") val state = new SymbState(ex, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) @@ -175,15 +182,15 @@ class TestSymbStateRewriterSet extends RewriterBase with TestingPredefs { assert(solverContext.sat()) solverContext.pop() solverContext.push() - solverContext.assertGroundExpr(tla.not(nextState.ex)) + solverContext.assertGroundExpr(not(nextState.ex ? "b").typed(types, "b")) assert(!solverContext.sat()) solverContext.pop() } - test("""SE-SET-IN2: \FALSE \in {\FALSE, \TRUE} ~~> b_new""") { + test("""FALSE \in {FALSE, TRUE}""") { val ex = - OperEx(TlaSetOper.in, ValEx(TlaBool(false)), - OperEx(TlaSetOper.enumSet, ValEx(TlaBool(false)), ValEx(TlaBool(true)))) + in(bool(false), enumSet(bool(false), bool(true)) ? "B") + .typed(types, "b") val state = new SymbState(ex, arena, Binding()) val rewriter = create() rewriter.rewriteOnce(state) match { @@ -191,7 +198,7 @@ class TestSymbStateRewriterSet extends RewriterBase with TestingPredefs { nextState.ex match { case predEx @ NameEx(name) => rewriter.push() - solverContext.assertGroundExpr(OperEx(TlaBoolOper.not, predEx)) + solverContext.assertGroundExpr(not(predEx ? "b").typed(types, "b")) assert(!solverContext.sat()) rewriter.pop() solverContext.assertGroundExpr(predEx) @@ -206,15 +213,16 @@ class TestSymbStateRewriterSet extends RewriterBase with TestingPredefs { } } - test("""SE-SET-NOTIN1: ~(\FALSE \in {\FALSE, \TRUE}) ~~> b_new""") { + test("""~(FALSE \in {FALSE, TRUE})""") { val ex = - tla.not(tla.in(tla.bool(false), tla.enumSet(tla.bool(false), tla.bool(true)))) + not(in(bool(false), enumSet(bool(false), bool(true)) ? "B") ? "b") + .typed(types, "b") val state = new SymbState(ex, arena, Binding()) val rewriter = create() rewriter.rewriteUntilDone(state).ex match { case predEx @ NameEx(name) => rewriter.push() - solverContext.assertGroundExpr(OperEx(TlaBoolOper.not, predEx)) + solverContext.assertGroundExpr(not(predEx ? "b").typed(types, "b")) assert(solverContext.sat()) rewriter.pop() solverContext.assertGroundExpr(predEx) @@ -225,11 +233,12 @@ class TestSymbStateRewriterSet extends RewriterBase with TestingPredefs { } } - test("""SE-SET-IN3: c_i: Bool \in {\TRUE, \TRUE} ~~> b_new""") { + test("""c_i \in {TRUE, TRUE}""") { arena = arena.appendCell(BoolT()) val cell = arena.topCell val ex = - OperEx(TlaSetOper.in, cell.toNameEx, OperEx(TlaSetOper.enumSet, ValEx(TlaBool(true)), ValEx(TlaBool(true)))) + in(cell.toNameEx ? "b", enumSet(bool(true), bool(true)) ? "B") + .typed(types, "b") val state = new SymbState(ex, arena, Binding()) val rewriter = create() rewriter.rewriteOnce(state) match { @@ -238,14 +247,14 @@ class TestSymbStateRewriterSet extends RewriterBase with TestingPredefs { case predEx @ NameEx(name) => rewriter.push() // cell = \TRUE - solverContext.assertGroundExpr(OperEx(TlaOper.eq, arena.cellTrue().toNameEx, cell.toNameEx)) + solverContext.assertGroundExpr(eql(arena.cellTrue().toNameEx ? "b", cell.toNameEx ? "b").typed(types, "b")) // and membership holds true solverContext.assertGroundExpr(predEx) assert(solverContext.sat()) rewriter.pop() // another query // cell = \FALSE - solverContext.assertGroundExpr(OperEx(TlaOper.eq, arena.cellFalse().toNameEx, cell.toNameEx)) + solverContext.assertGroundExpr(eql(arena.cellFalse().toNameEx ? "b", cell.toNameEx ? "b").typed(types, "b")) // and membership holds true solverContext.assertGroundExpr(predEx) // contradiction @@ -260,9 +269,10 @@ class TestSymbStateRewriterSet extends RewriterBase with TestingPredefs { } } - test("""SE-SET-IN1 (shortcut): 1 \in {1} ~~> $B$k""") { - // there is a special shortcut rule for singleton sets, which had a bug - val ex = tla.in(tla.int(1), tla.enumSet(tla.int(1))) + test("""1 \in {1}""") { + // regression: there is a special shortcut rule for singleton sets, which had a bug + val ex = in(int(1), enumSet(int(1)) ? "I") + .typed(types, "b") val state = new SymbState(ex, arena, Binding()) val rewriter = create() @@ -275,7 +285,7 @@ class TestSymbStateRewriterSet extends RewriterBase with TestingPredefs { assert(solverContext.sat()) rewriter.pop() rewriter.push() - solverContext.assertGroundExpr(OperEx(TlaBoolOper.not, predEx)) + solverContext.assertGroundExpr(not(predEx ? "b").typed(types, "b")) assert(!solverContext.sat()) rewriter.pop() @@ -284,25 +294,26 @@ class TestSymbStateRewriterSet extends RewriterBase with TestingPredefs { } } - test("""SE-SET-NOTIN1: c_i: ~(Bool \in {\TRUE, \TRUE}) ~~> c_pred""") { + test("""~(Bool \in {TRUE, TRUE})""") { arena = arena.appendCell(BoolT()) val cell = arena.topCell val ex = - tla.not(tla.in(cell.toNameEx, tla.enumSet(tla.bool(true), tla.bool(true)))) + not(in(cell.toNameEx ? "b", enumSet(bool(true), bool(true)) ? "B") ? "b") + .typed(types, "b") val state = new SymbState(ex, arena, Binding()) val rewriter = create() rewriter.rewriteUntilDone(state).ex match { case predEx @ NameEx(name) => rewriter.push() - // cell = \TRUE - solverContext.assertGroundExpr(OperEx(TlaOper.eq, arena.cellTrue().toNameEx, cell.toNameEx)) + // cell = TRUE + solverContext.assertGroundExpr(eql(arena.cellTrue().toNameEx ? "b", cell.toNameEx ? "b").typed(types, "b")) // and membership holds true solverContext.assertGroundExpr(predEx) assert(!solverContext.sat()) rewriter.pop() // another query // cell = \FALSE - solverContext.assertGroundExpr(OperEx(TlaOper.eq, arena.cellFalse().toNameEx, cell.toNameEx)) + solverContext.assertGroundExpr(eql(arena.cellFalse().toNameEx ? "b", cell.toNameEx ? "b").typed(types, "b")) // and membership holds true solverContext.assertGroundExpr(predEx) // no contradiction here @@ -313,20 +324,23 @@ class TestSymbStateRewriterSet extends RewriterBase with TestingPredefs { } } - test("""SE-SET-IN3: {{}, {{}, {}}} \in {{}, {{}, {{}, {}}}} ~~> b_new""") { - def mkSet(elems: TlaEx*) = OperEx(TlaSetOper.enumSet, elems: _*) - def intSet() = emptySetWithType(IntT()) - def int2Set() = emptySetWithType(FinSetT(IntT())) - def int3Set() = emptySetWithType(FinSetT(FinSetT(IntT()))) + test("""{{}, {{}, {}}} \in {{}, {{}, {{}, {}}}}""") { + def intSet() = enumSet().typed(types, "I") + + def int2Set() = enumSet().typed(types, "II") - val left = mkSet(int2Set(), mkSet(intSet(), intSet())) - val right = mkSet(int3Set(), mkSet(int2Set(), mkSet(intSet(), intSet()))) - val ex = OperEx(TlaSetOper.in, left, right) + def int3Set() = enumSet().typed(types, "III") + + val left = enumSet(int2Set(), enumSet(intSet(), intSet()) ? "II") + .typed(types, "III") + val right = enumSet(int3Set(), enumSet(int2Set() ? "II", enumSet(intSet(), intSet()) ? "II") ? "III") + .typed(types, "IV") + val ex = in(left, right).typed(BoolT1()) val state = new SymbState(ex, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) nextState.ex match { - case predEx @ NameEx(name) => + case predEx @ NameEx(_) => rewriter.push() // and membership holds true solverContext.assertGroundExpr(predEx) @@ -334,7 +348,7 @@ class TestSymbStateRewriterSet extends RewriterBase with TestingPredefs { rewriter.pop() // another query // and membership does not hold - solverContext.assertGroundExpr(OperEx(TlaBoolOper.not, predEx)) + solverContext.assertGroundExpr(not(predEx ? "b").typed(types, "b")) assert(!solverContext.sat()) case _ => @@ -342,28 +356,32 @@ class TestSymbStateRewriterSet extends RewriterBase with TestingPredefs { } } - test("""SE-SET-IN3: {{}, {{{}}}} \in {{}, {{}, {{}}} ~~> b_new""") { - def mkSet(elems: TlaEx*) = OperEx(TlaSetOper.enumSet, elems: _*) - def intSet() = emptySetWithType(IntT()) - def int2Set() = emptySetWithType(FinSetT(IntT())) - def int3Set() = emptySetWithType(FinSetT(FinSetT(IntT()))) - def int4Set() = emptySetWithType(FinSetT(FinSetT(FinSetT(IntT())))) + test("""{{}, {{{}}}} \in {{}, {{}, {{}}}""") { + def intSet() = enumSet().typed(types, "I") - val left = mkSet(int3Set(), mkSet(mkSet(intSet()))) - val right = mkSet(int4Set(), mkSet(int3Set(), mkSet(int2Set()))) - val ex = OperEx(TlaSetOper.in, left, right) + def int2Set() = enumSet().typed(types, "II") + + def int3Set() = enumSet().typed(types, "III") + + def int4Set() = enumSet().typed(types, "IV") + + val left = enumSet(int3Set(), enumSet(enumSet(intSet()) ? "II") ? "III") + .typed(types, "IV") + val right = enumSet(int4Set(), enumSet(int3Set(), enumSet(int2Set()) ? "III") ? "IV") + .typed(types, "V") + val ex = in(left, right).typed(BoolT1()) val state = new SymbState(ex, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) nextState.ex match { - case predEx @ NameEx(name) => + case predEx @ NameEx(_) => rewriter.push() // set membership should not hold solverContext.assertGroundExpr(predEx) assert(!solverContext.sat()) rewriter.pop() // its negation holds true - solverContext.assertGroundExpr(OperEx(TlaBoolOper.not, predEx)) + solverContext.assertGroundExpr(not(predEx ? "b").typed(types, "b")) assert(solverContext.sat()) case _ => @@ -371,18 +389,18 @@ class TestSymbStateRewriterSet extends RewriterBase with TestingPredefs { } } - test("""SE-SET-EQ1: {{}} = {} ~~> $B$... (false)""") { - def mkSet(elems: TlaEx*) = OperEx(TlaSetOper.enumSet, elems: _*) + test("""{{}} = {} ~~> (false)""") { + def intSet() = enumSet().typed(types, "I") - def intSet() = emptySetWithType(IntT()) // empty sets need types - def int2Set() = emptySetWithType(FinSetT(IntT())) // empty sets need types + def int2Set() = enumSet().typed(types, "II") - val ex = tla.eql(tla.enumSet(intSet()), int2Set()) + val ex = eql(enumSet(intSet()) ? "II", int2Set()) + .typed(types, "b") val state = new SymbState(ex, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) nextState.ex match { - case predEx @ NameEx(name) => + case predEx @ NameEx(_) => rewriter.push() // not equal solverContext.assertGroundExpr(predEx) @@ -393,14 +411,18 @@ class TestSymbStateRewriterSet extends RewriterBase with TestingPredefs { } } - test("""SE-SET-EQ1: {{}, {{}}} = {{}, {{{}}} ~~> $B$... (false)""") { - def intSet() = emptySetWithType(IntT()) - def int2Set() = emptySetWithType(FinSetT(IntT())) - def int3Set() = emptySetWithType(FinSetT(FinSetT(IntT()))) + test("""{{}, {{}}} = {{}, {{{}}} ~~> (false)""") { + def intSet() = enumSet().typed(types, "I") + + def int2Set() = enumSet().typed(types, "II") + + def int3Set() = enumSet().typed(types, "III") - val left = tla.enumSet(int3Set(), tla.enumSet(int2Set())) - val right = tla.enumSet(int3Set(), tla.enumSet(tla.enumSet(intSet()))) - val ex = OperEx(TlaOper.eq, left, right) + val left = enumSet(int3Set(), enumSet(int2Set()) ? "III") + .typed(types, "IV") + val right = enumSet(int3Set(), enumSet(enumSet(intSet()) ? "II") ? "III") + .typed(types, "IV") + val ex = eql(left, right).typed(BoolT1()) val state = new SymbState(ex, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) @@ -416,18 +438,22 @@ class TestSymbStateRewriterSet extends RewriterBase with TestingPredefs { } } - test("""SE-SET-EQ1: {{}, {{}}} = {{}, {{}} ~~> $B$... (true)""") { - def intSet() = emptySetWithType(IntT()) - def int2Set() = emptySetWithType(FinSetT(IntT())) + test("""{{}, {{}}} = {{}, {{}} ~~> (true)""") { + def intSet() = enumSet().typed(types, "I") - val left = tla.enumSet(int2Set(), tla.enumSet(intSet())) - val right = tla.enumSet(int2Set(), tla.enumSet(intSet())) - val ex = OperEx(TlaOper.eq, left, right) + def int2Set() = enumSet().typed(types, "II") + + val left = enumSet(int2Set(), enumSet(intSet()) ? "II") + .typed(types, "III") + val right = enumSet(int2Set() ? "II", enumSet(intSet()) ? "II") + .typed(types, "III") + val ex = eql(left, right) + .typed(BoolT1()) val state = new SymbState(ex, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) nextState.ex match { - case predEx @ NameEx(name) => + case predEx @ NameEx(_) => rewriter.push() // not equal solverContext.assertGroundExpr(predEx) @@ -438,15 +464,16 @@ class TestSymbStateRewriterSet extends RewriterBase with TestingPredefs { } } - test("""SE-SET-EQ1: {} = {1} \ {1} ~~> $B$... (true)""") { - def intSet() = emptySetWithType(IntT()) - val setOf1 = tla.enumSet(tla.int(1)) + test("""{} = {1} \ {1} ~~> (true)""") { + def intSet() = enumSet().typed(types, "I") - def dynEmpty(left: TlaEx): TlaEx = { - tla.filter(tla.name("t"), left, tla.bool(false)) - } + val setOf1 = enumSet(int(1)).typed(types, "I") + + val dynEmpty = + filter(name("t") ? "i", setOf1, bool(false)) + .typed(types, "I") - val ex = OperEx(TlaOper.eq, intSet(), dynEmpty(setOf1)) + val ex = eql(intSet(), dynEmpty).typed(BoolT1()) val state = new SymbState(ex, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) @@ -457,7 +484,7 @@ class TestSymbStateRewriterSet extends RewriterBase with TestingPredefs { solverContext.assertGroundExpr(predEx) assert(solverContext.sat()) rewriter.pop() - solverContext.assertGroundExpr(tla.not(predEx)) + solverContext.assertGroundExpr(not(predEx ? "b").typed(types, "b")) assert(!solverContext.sat()) case _ => @@ -465,31 +492,39 @@ class TestSymbStateRewriterSet extends RewriterBase with TestingPredefs { } } - test("""type incorrect {1} \ {1} = {FALSE} \ {FALSE}""") { + test("""type incorrect { i \in {1}: FALSE } = { b \in {FALSE}: FALSE }""") { // This test worked in the previous versions. // Now we enforce type correctness, and reject this expression right after type checking. - val setOfOne = tla.enumSet(tla.int(1)) - val setOfFalse = tla.enumSet(tla.bool(false)) - val ex = OperEx(TlaOper.eq, tla.setminus(setOfFalse, setOfFalse), tla.setminus(setOfOne, setOfOne)) + // Although we keep this test, it cannot originate from a well-typed TLA+ code. + val intFilter = filter(name("i") ? "i", enumSet(int(1)) ? "I", bool(false)) + .typed(types, "I") + val boolFilter = filter(name("b") ? "B", enumSet(bool(false)) ? "B", bool(false)) + .typed(types, "B") + val ex = eql(intFilter ? "I", boolFilter ? "B") + .typed(types, "b") val state = new SymbState(ex, arena, Binding()) val rewriter = create() - assertThrows[TypeInferenceException] { + assertThrows[MalformedTlaError] { rewriter.rewriteUntilDone(state) } } - test("""SE-SET-NE1: ~({{}, {{}}} = {{}, {{}}}) ~~> $B$... (false)""") { - def intSet() = emptySetWithType(IntT()) - def int2Set() = emptySetWithType(FinSetT(IntT())) + test("""~({{}, {{}}} = {{}, {{}}})""") { + def intSet() = enumSet().typed(types, "I") - val left = tla.enumSet(int2Set(), tla.enumSet(intSet())) - val right = tla.enumSet(int2Set(), tla.enumSet(intSet())) - val ex = tla.not(tla.eql(left, right)) + def int2Set() = enumSet().typed(types, "II") + + val left = enumSet(int2Set(), enumSet(intSet()) ? "II") + .typed(types, "III") + val right = enumSet(int2Set(), enumSet(intSet()) ? "II") + .typed(types, "III") + val ex = not(eql(left, right) ? "b") + .typed(types, "b") val state = new SymbState(ex, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) nextState.ex match { - case predEx @ NameEx(name) => + case predEx @ NameEx(_) => rewriter.push() // not equal solverContext.assertGroundExpr(predEx) @@ -499,13 +534,15 @@ class TestSymbStateRewriterSet extends RewriterBase with TestingPredefs { fail("Unexpected rewriting result") } } - test("""SE-SET-FILTER[1-2]: {x \in {1,2,3,4} : x % 2 = 0} ~~> {2, 4}""") { - def mkSet(elems: TlaEx*) = OperEx(TlaSetOper.enumSet, elems: _*) - - val set = mkSet(ValEx(TlaInt(1)), ValEx(TlaInt(2)), ValEx(TlaInt(3)), ValEx(TlaInt(4))) - val xMod2 = OperEx(TlaArithOper.mod, NameEx("x"), ValEx(TlaInt(2))) - val filter = OperEx(TlaOper.eq, xMod2, ValEx(TlaInt(0))) - val ex = OperEx(TlaSetOper.filter, NameEx("x"), set, filter) + test("""{x \in {1,2,3,4} : x % 2 = 0} ~~> {2, 4}""") { + val set = enumSet(int(1), int(2), int(3), int(4)) + .typed(types, "I") + val xMod2 = mod(name("x") ? "i", int(2)) + .typed(types, "i") + val pred = eql(xMod2, int(0)) + .typed(types, "i") + val ex = filter(name("x") ? "i", set, pred) + .typed(types, "I") val state = new SymbState(ex, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) @@ -520,24 +557,26 @@ class TestSymbStateRewriterSet extends RewriterBase with TestingPredefs { } } - test("""SE-SET-FILTER[1-2]: 2 \in {x \in {1,2,3,4} : x < 3} ~~> $B$k""") { - def mkSet(elems: TlaEx*) = OperEx(TlaSetOper.enumSet, elems: _*) - - val set = mkSet(ValEx(TlaInt(1)), ValEx(TlaInt(2)), ValEx(TlaInt(3)), ValEx(TlaInt(4))) - val filter = tla.lt(tla.name("x"), tla.int(3)) - val filteredSet = OperEx(TlaSetOper.filter, NameEx("x"), set, filter) - val inFilteredSet = OperEx(TlaSetOper.in, ValEx(TlaInt(2)), filteredSet) + test("""2 \in {x \in {1,2,3,4} : x < 3}""") { + val set = enumSet(int(1), int(2), int(3), int(4)) + .typed(types, "I") + val pred = lt(name("x") ? "i", int(3)) + .typed(types, "b") + val filteredSet = filter(name("x") ? "i", set, pred) + .typed(types, "I") + val inFilteredSet = in(int(2), filteredSet) + .typed(types, "b") val state = new SymbState(inFilteredSet, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) nextState.ex match { - case membershipEx @ NameEx(name) => + case membershipEx @ NameEx(_) => rewriter.push() solverContext.assertGroundExpr(membershipEx) assert(solverContext.sat()) rewriter.pop() - solverContext.assertGroundExpr(OperEx(TlaBoolOper.not, membershipEx)) + solverContext.assertGroundExpr(not(membershipEx ? "b").typed(types, "b")) assert(!solverContext.sat()) case _ => @@ -545,115 +584,92 @@ class TestSymbStateRewriterSet extends RewriterBase with TestingPredefs { } } - // until the real type inference is implemented - ignore("""SE-SET-FILTER[1-2]: LET X = {1, 2} \cap {2} IN {} = {x \in X : [y \in X |-> TRUE][x]} ~~> $B$k""") { - // regression - val filter = tla.appFun(tla.funDef(tla.bool(true), tla.name("y"), tla.name("Oper:X")), tla.name("x")) - val filteredSet = tla.filter(tla.name("x"), tla.name("Oper:X"), filter) - val set12 = tla.enumSet(tla.int(1), tla.int(2)) - val set2 = tla.enumSet(tla.int(2)) - val ex = - OperEx(BmcOper.skolem, - tla.letIn(tla.eql(tla.enumSet(), filteredSet), tla.declOp("X", tla.cap(set12, set2)).untypedOperDecl())) + test("""{Q \in Expand(SUBSET {1,2,3}) : ~(2 \in Q)}""") { + val set = enumSet(1.to(3).map(int): _*).typed(types, "I") - val state = new SymbState(ex, arena, Binding()) - val rewriter = new SymbStateRewriterImpl(solverContext, new TrivialTypeFinder()) - val nextState = rewriter.rewriteUntilDone(state) - nextState.ex match { - case membershipEx @ NameEx(name) => - rewriter.push() - val failPreds = nextState.arena.findCellsByType(FailPredT()) - val failureOccurs = tla.or(failPreds.map(_.toNameEx): _*) - solverContext.assertGroundExpr(failureOccurs) - assert(!solverContext.sat()) // no failure should be possible - - case _ => - fail("Unexpected rewriting result") - } - } - - test("""SE-SET-FILTER: {Q \in Expand(SUBSET {1,2,3}) : ~(2 \in Q)}""") { - val set = tla.enumSet(1.to(3).map(tla.int): _*) - - val predEx = tla.not(tla.in(tla.int(2), tla.name("Q"))) - val expandedPowSet = OperEx(BmcOper.expand, tla.powSet(set)) - val ex = tla.filter(tla.name("Q"), expandedPowSet, predEx) + val predEx = not(in(int(2), name("Q") ? "I") ? "b") + .typed(types, "b") + val expandedPowSet = apalacheExpand(powSet(set) ? "II") + val ex = filter(name("Q") ? "I", expandedPowSet ? "II", predEx ? "b") + .typed(types, "II") val state = new SymbState(ex, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) assert(solverContext.sat()) - assertTlaExAndRestore(rewriter, nextState.setRex(tla.in(tla.enumSet(tla.int(1), tla.int(3)), nextState.ex))) - assertTlaExAndRestore(rewriter, - nextState.setRex(tla.not(tla.in(tla.enumSet(tla.int(1), tla.int(2)), nextState.ex)))) + val inPred = in(enumSet(int(1), int(3)) ? "I", nextState.ex) + .typed(types, "b") + assertTlaExAndRestore(rewriter, nextState.setRex(inPred)) + val notInPred = not(in(enumSet(int(1), int(2)).typed(types, "I"), nextState.ex) ? "b") + .typed(types, "b") + assertTlaExAndRestore(rewriter, nextState.setRex(notInPred)) } - test("""SE-SET-FILTER[1-2]: \E SUBSET X {1} IN {} = {x \in X : [y \in X |-> TRUE][x]} ~~> $B$k""") { + test("""\E X \in SUBSET {1} IN {} = {x \in X : [y \in X |-> TRUE][x]}""") { // regression - val baseSet = tla.enumSet(tla.int(1)) - val filter = tla.appFun(tla.funDef(tla.bool(true), tla.name("y"), tla.name("X")), tla.name("x")) - val filteredSet = tla.filter(tla.name("x"), tla.name("X"), filter) + val baseSet = enumSet(int(1)) + .typed(types, "I") + val pred = appFun(funDef(bool(true), name("y") ? "i", name("X") ? "I") ? "i_to_b", name("x") ? "i") + .typed(types, "b") + val filteredSet = filter(name("x") ? "i", name("X") ? "II", pred ? "b") + .typed(types, "I") val ex = - OperEx(BmcOper.skolem, tla.exists(tla.name("X"), tla.powSet(baseSet), tla.eql(tla.enumSet(), filteredSet))) + apalacheSkolem(exists(name("X") ? "I", powSet(baseSet) ? "II", eql(enumSet() ? "I", filteredSet) ? "b") ? "b") + .typed(types, "b") val state = new SymbState(ex, arena, Binding()) - val rewriter = new SymbStateRewriterImpl(solverContext, new TrivialTypeFinder()) + val rewriter = new SymbStateRewriterImpl(solverContext) val nextState = rewriter.rewriteUntilDone(state) nextState.ex match { - case membershipEx @ NameEx(name) => + case NameEx(_) => assert(solverContext.sat()) rewriter.push() - val failPreds = nextState.arena.findCellsByType(FailPredT()) - val failureOccurs = tla.or(failPreds.map(_.toNameEx): _*) - solverContext.assertGroundExpr(failureOccurs) - assert(!solverContext.sat()) // no failure should be possible case _ => fail("Unexpected rewriting result") } } - test("""SE-SET-FILTER[1-2]: \E X \in SUBSET {1, 2}: {} = {x \in X : [y \in {1} |-> TRUE][x]} ~~> $B$k""") { + test("""\E X \in SUBSET {1, 2}: {} = {x \in X : [y \in {1} |-> TRUE][x]}""") { // regression - val baseSet = tla.enumSet(tla.int(1)) - val set1 = tla.enumSet(tla.int(1)) - val filter = tla.appFun(tla.funDef(tla.bool(true), tla.name("y"), set1), tla.name("x")) - val filteredSet = tla.filter(tla.name("x"), tla.name("X"), filter) - val set12 = tla.enumSet(tla.int(1), tla.int(2)) + val set1 = enumSet(int(1)) + .typed(types, "I") + val pred = appFun(funDef(bool(true), name("y") ? "i", set1) ? "i_to_b", name("x") ? "i") + .typed(types, "b") + val filteredSet = filter(name("x") ? "i", name("X") ? "I", pred) + .typed(types, "I") + val set12 = enumSet(int(1), int(2)) + .typed(types, "I") val ex = - OperEx(BmcOper.skolem, tla.exists(tla.name("X"), tla.powSet(set12), tla.eql(tla.enumSet(), filteredSet))) + apalacheSkolem(exists(name("X") ? "I", powSet(set12) ? "II", eql(enumSet() ? "I", filteredSet) ? "b") ? "b") + .typed(types, "b") val state = new SymbState(ex, arena, Binding()) - val rewriter = new SymbStateRewriterImpl(solverContext, new TrivialTypeFinder()) + val rewriter = new SymbStateRewriterImpl(solverContext) val nextState = rewriter.rewriteUntilDone(state) nextState.ex match { - case membershipEx @ NameEx(name) => + case NameEx(_) => // the new implementation just returns a default value, as in the classical TLA+ interpretation assert(solverContext.sat()) // the result should be true, although some values may be undefined solverContext.assertGroundExpr(nextState.ex) assert(solverContext.sat()) - /* - // the old implementation with failure predicates - rewriter.push() - val failPreds = nextState.arena.findCellsByType(FailPredT()) - val failureOccurs = tla.or(failPreds.map(_.toNameEx): _*) - solverContext.assertGroundExpr(failureOccurs) - assert(solverContext.sat()) // failure should be possible - */ case _ => fail("Unexpected rewriting result") } } - test("""SE-SET-FILTER[1-2]: 3 \in {x \in {2, 3} : x % 2 = 0} ~~> $B$k""") { - def mkSet(elems: TlaEx*) = OperEx(TlaSetOper.enumSet, elems: _*) - - val set = mkSet(ValEx(TlaInt(2)), ValEx(TlaInt(3))) - val xMod2 = OperEx(TlaArithOper.mod, NameEx("x"), ValEx(TlaInt(2))) - val filter = OperEx(TlaOper.eq, xMod2, ValEx(TlaInt(0))) - val filteredSet = OperEx(TlaSetOper.filter, NameEx("x"), set, filter) - val inFilteredSet = OperEx(TlaSetOper.in, ValEx(TlaInt(3)), filteredSet) + test("""3 \in {x \in {2, 3} : x % 2 = 0}""") { + val set = enumSet(int(2), int(3)) + .typed(types, "I") + val xMod2 = mod(name("x") ? "i", int(2)) + .typed(types, "i") + val pred = eql(xMod2, int(0)) + .typed(types, "b") + val filteredSet = filter(name("x") ? "i", set, pred) + .typed(types, "I") + val inFilteredSet = in(int(3), filteredSet) + .typed(types, "b") val state = new SymbState(inFilteredSet, arena, Binding()) val rewriter = create() @@ -664,7 +680,7 @@ class TestSymbStateRewriterSet extends RewriterBase with TestingPredefs { solverContext.assertGroundExpr(membershipEx) assert(!solverContext.sat()) rewriter.pop() - solverContext.assertGroundExpr(OperEx(TlaBoolOper.not, membershipEx)) + solverContext.assertGroundExpr(not(membershipEx).typed(types, "b")) assert(solverContext.sat()) case _ => @@ -672,15 +688,18 @@ class TestSymbStateRewriterSet extends RewriterBase with TestingPredefs { } } - test("""SE-SET-MAP[1-2]: {x / 3: x \in {1,2,3,4}} ~~> $C$k""") { - val set = tla.enumSet(1 to 4 map tla.int: _*) - val mapping = tla.div(tla.name("x"), tla.int(3)) - val mappedSet = tla.map(mapping, tla.name("x"), set) + test("""{ x / 3: x \in {1,2,3,4} }""") { + val set = enumSet(1 to 4 map int: _*) + .typed(types, "I") + val mapping = div(name("x") ? "i", int(3)) + .typed(types, "i") + val mappedSet = map(mapping, name("x") ? "i", set) + .typed(types, "I") val state = new SymbState(mappedSet, arena, Binding()) val nextState = create().rewriteUntilDone(state) nextState.ex match { - case membershipEx @ NameEx(name) => + case NameEx(_) => assert(solverContext.sat()) // membership tests are in the tests below @@ -689,11 +708,15 @@ class TestSymbStateRewriterSet extends RewriterBase with TestingPredefs { } } - test("""SE-SET-MAP[1-2]: 0 \in {x / 3: x \in {1,2,3,4}} ~~> $B$k""") { - val set = tla.enumSet(1 to 4 map tla.int: _*) - val mapping = tla.div(tla.name("x"), tla.int(3)) - val mappedSet = tla.map(mapping, tla.name("x"), set) - val inMappedSet = tla.in(tla.int(0), mappedSet) + test("""0 \in {x / 3: x \in {1,2,3,4}}""") { + val set = enumSet(1 to 4 map int: _*) + .typed(types, "I") + val mapping = div(name("x") ? "i", int(3)) + .typed(types, "i") + val mappedSet = map(mapping, name("x") ? "i", set) + .typed(types, "I") + val inMappedSet = in(int(0), mappedSet) + .typed(types, "b") val state = new SymbState(inMappedSet, arena, Binding()) val rewriter = create() @@ -704,7 +727,7 @@ class TestSymbStateRewriterSet extends RewriterBase with TestingPredefs { solverContext.assertGroundExpr(membershipEx) assert(solverContext.sat()) rewriter.pop() - solverContext.assertGroundExpr(OperEx(TlaBoolOper.not, membershipEx)) + solverContext.assertGroundExpr(not(membershipEx ? "b").typed(types, "b")) assert(!solverContext.sat()) case _ => @@ -712,11 +735,15 @@ class TestSymbStateRewriterSet extends RewriterBase with TestingPredefs { } } - test("""SE-SET-MAP[1-2]: 2 \in {x / 3: x \in {1,2,3,4}} ~~> $B$k""") { - val set = tla.enumSet(1 to 4 map tla.int: _*) - val mapping = tla.div(tla.name("x"), tla.int(3)) - val mappedSet = tla.map(mapping, tla.name("x"), set) - val inMappedSet = tla.in(tla.int(2), mappedSet) + test("""2 \in {x / 3: x \in {1,2,3,4}}""") { + val set = enumSet(1 to 4 map int: _*) + .typed(types, "I") + val mapping = div(name("x") ? "i", int(3)) + .typed(types, "i") + val mappedSet = map(mapping, name("x") ? "i", set) + .typed(types, "I") + val inMappedSet = in(int(2), mappedSet) + .typed(types, "b") val state = new SymbState(inMappedSet, arena, Binding()) val rewriter = create() @@ -727,7 +754,7 @@ class TestSymbStateRewriterSet extends RewriterBase with TestingPredefs { solverContext.assertGroundExpr(membershipEx) assert(!solverContext.sat()) rewriter.pop() - solverContext.assertGroundExpr(OperEx(TlaBoolOper.not, membershipEx)) + solverContext.assertGroundExpr(not(membershipEx ? "b").typed(types, "b")) assert(solverContext.sat()) case _ => @@ -737,21 +764,26 @@ class TestSymbStateRewriterSet extends RewriterBase with TestingPredefs { // type inference would reject this test("""error: {x: x \in Int}""") { - val set = ValEx(TlaIntSet) - val mapExpr = tla.name("x") - val mapSet = tla.map(mapExpr, tla.name("x"), set) + val set = intSet().typed(types, "I") + val mapSet = map(name("x") ? "i", name("x") ? "i", set) + .typed(types, "I") val state = new SymbState(mapSet, arena, Binding()) val rewriter = create() - assertThrows[TypeInferenceException](rewriter.rewriteUntilDone(state)) + assertThrows[TlaInputError](rewriter.rewriteUntilDone(state)) } - test("""SE-SET-MAP[1-2]: <<2, true>> \in {<>: x \in {1,2,3}, y \in {FALSE, TRUE}} ~~> $B$k""") { - val set123 = tla.enumSet(1 to 3 map tla.int: _*) - val setBool = tla.enumSet(tla.bool(false), tla.bool(true)) - val mapping = tla.tuple(tla.name("x"), tla.name("y")) - val mappedSet = tla.map(mapping, tla.name("x"), set123, tla.name("y"), setBool) - val inMappedSet = tla.in(tla.tuple(tla.int(2), tla.bool(true)), mappedSet) + test("""<<2, true>> \in {<>: x \in {1,2,3}, y \in {FALSE, TRUE}}""") { + val set123 = enumSet(1 to 3 map int: _*) + .typed(types, "I") + val setBool = enumSet(bool(false), bool(true)) + .typed(types, "B") + val mapping = tuple(name("x") ? "i", name("y") ? "b") + .typed(types, "ib") + val mappedSet = map(mapping, name("x") ? "i", set123, name("y") ? "b", setBool) + .typed(types, "IB") + val inMappedSet = in(tuple(int(2), bool(true)) ? "ib", mappedSet) + .typed(types, "b") val state = new SymbState(inMappedSet, arena, Binding()) val rewriter = create() @@ -762,7 +794,7 @@ class TestSymbStateRewriterSet extends RewriterBase with TestingPredefs { solverContext.assertGroundExpr(membershipEx) assert(solverContext.sat()) rewriter.pop() - solverContext.assertGroundExpr(OperEx(TlaBoolOper.not, membershipEx)) + solverContext.assertGroundExpr(not(membershipEx ? "b").typed(types, "b")) assert(!solverContext.sat()) case _ => @@ -770,14 +802,19 @@ class TestSymbStateRewriterSet extends RewriterBase with TestingPredefs { } } - test("""SE-SET-MAP[1-2]: <> \in {<>: x \in {1,2} \ {2}, y \in {FALSE, TRUE}}""") { + test("""<> \in {<>: x \in {1,2} \ {2}, y \in {FALSE, TRUE}}""") { // this expression tests regressions in cached expressions // we express {1, 2} \ {2} as a filter, as set difference is not in KerA+ - val set12minus2 = tla.filter(tla.name("z"), tla.enumSet(tla.int(1), tla.int(2)), tla.eql(tla.name("z"), tla.int(1))) - val setBool = tla.enumSet(tla.bool(false), tla.bool(true)) - val mapping = tla.tuple(tla.name("y")) - val mappedSet = tla.map(mapping, tla.name("x"), set12minus2, tla.name("y"), setBool) - val inMappedSet = tla.in(tla.tuple(tla.bool(true)), mappedSet) + val set12minus2 = filter(name("z") ? "i", enumSet(int(1), int(2)) ? "I", eql(name("z") ? "i", int(1)) ? "b") + .typed(types, "I") + val setBool = enumSet(bool(false), bool(true)) + .typed(types, "B") + val mapping = tuple(name("y") ? "b") + .typed(types + ("(b)" -> TupT1(BoolT1())), "(b)") + val mappedSet = map(mapping, name("x") ? "i", set12minus2 ? "I", name("y") ? "b", setBool) + .typed(types + ("(B)" -> SetT1(TupT1(BoolT1()))), "(B)") + val inMappedSet = in(tuple(bool(true)).typed(TupT1(BoolT1())), mappedSet) + .typed(BoolT1()) val state = new SymbState(inMappedSet, arena, Binding()) val rewriter = create() @@ -788,7 +825,7 @@ class TestSymbStateRewriterSet extends RewriterBase with TestingPredefs { solverContext.assertGroundExpr(membershipEx) assert(solverContext.sat()) rewriter.pop() - solverContext.assertGroundExpr(OperEx(TlaBoolOper.not, membershipEx)) + solverContext.assertGroundExpr(not(membershipEx ? "b").typed(types, "b")) assert(!solverContext.sat()) case _ => @@ -797,58 +834,64 @@ class TestSymbStateRewriterSet extends RewriterBase with TestingPredefs { } // Regression for the issue 365: https://github.com/informalsystems/apalache/issues/365 - // This test goes through without a need for a fix. test("""MAP: \E S \in SUBSET { [a: "a", b: 1], [a: "a", b: 2] }: "a" \in { r.a: r \in S }""") { // this test reveals a deep bug in the encoding: SUBSET {[a: 1, b: 1], [a: 1, b: 2]} produces a powerset, // whose elements are sets that refer to the same cells, // namely the cells for the records [a: 1, b: 1] and [a: 1, b: 2]. - // If one record is included in a subset, but the other is not, then the map rule produces a contradicting constraint + // If one record is included in a subset, but the other is not, then the map rule produced a contradicting constraint // for the element "a": it must be in the resulting set, and at the same time it must not be in the resulting set. - val rec1 = tla.enumFun(tla.str("a"), tla.str("a"), tla.str("b"), tla.int(1)) - val rec2 = tla.enumFun(tla.str("a"), tla.str("a"), tla.str("b"), tla.int(2)) - val base = tla.enumSet(rec1, rec2) - val powerset = tla.powSet(base) - val map = tla.map(tla.appFun(tla.name("r"), tla.str("a")), tla.name("r"), tla.name("S")) - val mem = tla.in(tla.str("a"), map) - val exists = OperEx(BmcOper.skolem, tla.exists(tla.name("S"), powerset, mem)) - + val recTypes = + Map("b" -> BoolT1(), "s" -> StrT1(), "S" -> SetT1(StrT1()), "r" -> RecT1("a" -> StrT1(), "b" -> IntT1()), + "R" -> SetT1(RecT1("a" -> StrT1(), "b" -> IntT1())), + "RR" -> SetT1(SetT1(RecT1("a" -> StrT1(), "b" -> IntT1())))) + val rec1 = enumFun(str("a"), str("a"), str("b"), int(1)) + .typed(recTypes, "r") + val rec2 = enumFun(str("a"), str("a"), str("b"), int(2)) + .typed(recTypes, "r") + val base = enumSet(rec1, rec2) + .typed(recTypes, "R") + val powerset = powSet(base) + .typed(recTypes, "RR") + val mapped = map(appFun(name("r") ? "r", str("a")) ? "s", name("r") ? "r", name("S") ? "R") + .typed(recTypes, "S") + val mem = in(str("a"), mapped) + .typed(BoolT1()) + val existsForm = apalacheSkolem(exists(name("S") ? "R", powerset ? "RR", mem) ? "b") + .typed(recTypes, "b") + + // the following test goes through without a need for a fix val rewriter = create() - val state = new SymbState(exists, arena, Binding()) + val state = new SymbState(existsForm, arena, Binding()) assumeTlaEx(rewriter, state) - } - - // Regression for the issue 365: https://github.com/informalsystems/apalache/issues/365 - // This test captures the core of the functional test in `test/tla/Fix365_ExistsSubset3.tla`. - test( - """MAP: \E S \in SUBSET { [a: "a", b: 1], [a: "a", b: 2] }: "a" \in { r.a: r \in S } /\ \A x \in S: x.b = 2""") { - // this tests reveals a deep bug in the encoding: SUBSET {[a: 1, b: 1], [a: 1, b: 2]} produces a powerset, - // whose elements are sets that refer to the same cells, - // namely the cells for the records [a: 1, b: 1] and [a: 1, b: 2]. - // If one record is included in a subset, but the other is not, then the map rule produces a contradicting constraint - // for the element "a": it must be in the resulting set, and at the same time it must be not in the resulting set. - val rec1 = tla.enumFun(tla.str("a"), tla.str("a"), tla.str("b"), tla.int(1)) - val rec2 = tla.enumFun(tla.str("a"), tla.str("a"), tla.str("b"), tla.int(2)) - val base = tla.enumSet(rec1, rec2) - val powerset = tla.powSet(base) - val map = tla.map(tla.appFun(tla.name("r"), tla.str("a")), tla.name("r"), tla.name("S")) - val mem = tla.in(tla.str("a"), map) - val forall = tla.forall(tla.name("x"), tla.name("S"), tla.eql(tla.appFun(tla.name("x"), tla.str("b")), tla.int(2))) - val and = tla.and(mem, forall) - val exists = OperEx(BmcOper.skolem, tla.exists(tla.name("S"), powerset, and)) - val rewriter = create() - val state = new SymbState(exists, arena, Binding()) - assumeTlaEx(rewriter, state) + // the following test captures the core of the functional test in `test/tla/Fix365_ExistsSubset3.tla`, + // which needed a fix + // \E S \in SUBSET { [a: "a", b: 1], [a: "a", b: 2] }: "a" \in { r.a: r \in S } /\ \A x \in S: x.b = 2 + val forallForm = + forall(name("x") ? "r", name("S") ? "R", eql(appFun(name("x") ? "r", str("b")) ? "b", int(2)) ? "b") + .typed(recTypes, "b") + val existsForm2 = apalacheSkolem(exists(name("S") ? "R", powerset, and(mem, forallForm) ? "b") ? "b") + .typed(recTypes, "b") + + // reset the solver and arena + solverContext = new PreproSolverContext(new Z3SolverContext(SolverConfig.default.copy(debug = true))) + arena = Arena.create(solverContext) + val rewriter2 = create() + val state2 = new SymbState(existsForm2, arena, Binding()) + assumeTlaEx(rewriter2, state2) } - test("""SE-SET-CUP[1-2]: {1, 3} \cup {3, 4} = {1, 3, 4}""") { - def mkSet(elems: TlaEx*) = OperEx(TlaSetOper.enumSet, elems: _*) - - val left = mkSet(ValEx(TlaInt(1)), ValEx(TlaInt(3))) - val right = mkSet(ValEx(TlaInt(3)), ValEx(TlaInt(4))) - val expected = mkSet(ValEx(TlaInt(1)), ValEx(TlaInt(3)), ValEx(TlaInt(4))) - val cupSet = OperEx(TlaSetOper.cup, left, right) - val eqExpected = OperEx(TlaOper.eq, cupSet, expected) + test("""{1, 3} \cup {3, 4} = {1, 3, 4}""") { + val left = enumSet(int(1), int(3)) + .typed(types, "I") + val right = enumSet(int(3), int(4)) + .typed(types, "I") + val expected = enumSet(int(1), int(3), int(4)) + .typed(types, "I") + val cupSet = cup(left, right) + .typed(types, "I") + val eqExpected = eql(cupSet, expected) + .typed(types, "b") val state = new SymbState(eqExpected, arena, Binding()) val rewriter = create() @@ -861,7 +904,7 @@ class TestSymbStateRewriterSet extends RewriterBase with TestingPredefs { solverContext.assertGroundExpr(predEx) assert(solverContext.sat()) rewriter.pop() - solverContext.assertGroundExpr(OperEx(TlaBoolOper.not, predEx)) + solverContext.assertGroundExpr(not(predEx ? "b").typed(types, "b")) assert(!solverContext.sat()) case _ => @@ -869,131 +912,24 @@ class TestSymbStateRewriterSet extends RewriterBase with TestingPredefs { } } - test("""SE-SUBSETEQ[1-3]: {1, 2} \subseteq {1, 2, 3} ~~> $B$... (true)""") { - val left = tla.enumSet(tla.int(1), tla.int(2)) - val right = tla.enumSet(tla.int(1), tla.int(2), tla.int(3)) - val ex = tla.subseteq(left, right) - val state = new SymbState(ex, arena, Binding()) - val rewriter = create() - val nextState = rewriter.rewriteUntilDone(state) - nextState.ex match { - case predEx @ NameEx(name) => - rewriter.push() - solverContext.assertGroundExpr(predEx) - assert(solverContext.sat()) - rewriter.pop() - rewriter.push() - solverContext.assertGroundExpr(tla.not(predEx)) - assertUnsatOrExplain(rewriter, nextState) - - case _ => - fail("Unexpected rewriting result") - } - } - - test("""SE-SUBSETEQ[1-3]: {1, 2, 3} \subseteq {1, 2, 3} ~~> $B$... (true)""") { - val right = tla.enumSet(tla.int(1), tla.int(2), tla.int(3)) - val ex = tla.subseteq(right, right) - val state = new SymbState(ex, arena, Binding()) - val rewriter = create() - val nextState = rewriter.rewriteUntilDone(state) - nextState.ex match { - case predEx @ NameEx(name) => - rewriter.push() - solverContext.assertGroundExpr(predEx) - assert(solverContext.sat()) - rewriter.pop() - rewriter.push() - solverContext.assertGroundExpr(tla.not(predEx)) - assertUnsatOrExplain(rewriter, nextState) - - case _ => - fail("Unexpected rewriting result") - } - } - - test("""SE-SUBSETEQ[1-3]: {} \subseteq {1, 2, 3} ~~> $B$... (true)""") { - val right = tla.enumSet(tla.int(1), tla.int(2), tla.int(3)) - // an empty set requires a type annotation - val ex = tla.subseteq(emptySetWithType(IntT()), right) - val state = new SymbState(ex, arena, Binding()) - val rewriter = create() - val nextState = rewriter.rewriteUntilDone(state) - nextState.ex match { - case predEx @ NameEx(name) => - rewriter.push() - solverContext.assertGroundExpr(predEx) - assert(solverContext.sat()) - rewriter.pop() - rewriter.push() - solverContext.assertGroundExpr(tla.not(predEx)) - assertUnsatOrExplain(rewriter, nextState) - - case _ => - fail("Unexpected rewriting result") - } - } - - test("""SE-SUBSETEQ[1-3]: {1, 4} \subseteq {1, 2, 3} ~~> $B$... (false)""") { - val left = tla.enumSet(tla.int(1), tla.int(4)) - val right = tla.enumSet(tla.int(1), tla.int(2), tla.int(3)) - val ex = tla.subseteq(left, right) - val state = new SymbState(ex, arena, Binding()) - val rewriter = create() - val nextState = rewriter.rewriteUntilDone(state) - nextState.ex match { - case predEx @ NameEx(name) => - rewriter.push() - solverContext.assertGroundExpr(predEx) - assertUnsatOrExplain(rewriter, nextState) - rewriter.pop() - rewriter.push() - solverContext.assertGroundExpr(tla.not(predEx)) - assert(solverContext.sat()) - - case _ => - fail("Unexpected rewriting result") - } - } - - // rewritten by Keramelizer - ignore("""SE-SUPSETEQ[1-3]: {1, 2, 3} \supseteq {1, 2} ~~> $B$... (true)""") { - val left = tla.enumSet(tla.int(1), tla.int(2)) - val right = tla.enumSet(tla.int(1), tla.int(2), tla.int(3)) - val ex = tla.supseteq(right, left) - val state = new SymbState(ex, arena, Binding()) - val rewriter = create() - val nextState = rewriter.rewriteUntilDone(state) - nextState.ex match { - case predEx @ NameEx(name) => - rewriter.push() - solverContext.assertGroundExpr(predEx) - assert(solverContext.sat()) - rewriter.pop() - rewriter.push() - solverContext.assertGroundExpr(tla.not(predEx)) - assertUnsatOrExplain(rewriter, nextState) - - case _ => - fail("Unexpected rewriting result") - } - } - - // rewritten by Keramelizer - ignore("""SE-SUPSETEQ[1-3]: {1, 2, 3} \supseteq {1, 2, 3} ~~> $B$... (true)""") { - val right = tla.enumSet(tla.int(1), tla.int(2), tla.int(3)) - val ex = tla.supseteq(right, right) + test("""{1, 2} \subseteq {1, 2, 3} ~~> (true)""") { + val left = enumSet(int(1), int(2)) + .typed(types, "I") + val right = enumSet(int(1), int(2), int(3)) + .typed(types, "I") + val ex = subseteq(left, right) + .typed(types, "b") val state = new SymbState(ex, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) nextState.ex match { - case predEx @ NameEx(name) => + case predEx @ NameEx(_) => rewriter.push() solverContext.assertGroundExpr(predEx) assert(solverContext.sat()) rewriter.pop() rewriter.push() - solverContext.assertGroundExpr(tla.not(predEx)) + solverContext.assertGroundExpr(not(predEx ? "b").typed(types, "b")) assertUnsatOrExplain(rewriter, nextState) case _ => @@ -1001,22 +937,22 @@ class TestSymbStateRewriterSet extends RewriterBase with TestingPredefs { } } - // rewritten by Keramelizer - ignore("""SE-SUPSETEQ[1-3]: {1, 2, 3} \supseteq {} ~~> $B$... (true)""") { - val right = tla.enumSet(tla.int(1), tla.int(2), tla.int(3)) - // an empty set requires a type annotation - val ex = tla.supseteq(right, emptySetWithType(IntT())) + test("""{1, 2, 3} \subseteq {1, 2, 3} ~~> (true)""") { + val right = enumSet(int(1), int(2), int(3)) + .typed(types, "I") + val ex = subseteq(right, right) + .typed(types, "b") val state = new SymbState(ex, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) nextState.ex match { - case predEx @ NameEx(name) => + case predEx @ NameEx(_) => rewriter.push() solverContext.assertGroundExpr(predEx) assert(solverContext.sat()) rewriter.pop() rewriter.push() - solverContext.assertGroundExpr(tla.not(predEx)) + solverContext.assertGroundExpr(not(predEx ? "b").typed(types, "b")) assertUnsatOrExplain(rewriter, nextState) case _ => @@ -1024,79 +960,12 @@ class TestSymbStateRewriterSet extends RewriterBase with TestingPredefs { } } - // rewritten by Keramelizer - ignore("""SE-SUPSETEQ[1-3]: {1, 2, 3} \supseteq {1, 4} ~~> $B$... (false)""") { - val left = tla.enumSet(tla.int(1), tla.int(4)) - val right = tla.enumSet(tla.int(1), tla.int(2), tla.int(3)) - val ex = tla.supseteq(right, left) - val state = new SymbState(ex, arena, Binding()) - val rewriter = create() - val nextState = rewriter.rewriteUntilDone(state) - nextState.ex match { - case predEx @ NameEx(name) => - rewriter.push() - solverContext.assertGroundExpr(predEx) - assertUnsatOrExplain(rewriter, nextState) - rewriter.pop() - rewriter.push() - solverContext.assertGroundExpr(tla.not(predEx)) - assert(solverContext.sat()) - - case _ => - fail("Unexpected rewriting result") - } - } - - // rewritten by Keramelizer - ignore("""SE-SUBSET[1-3]: {1, 2} \subset {1, 2, 3} ~~> $B$... (true)""") { - val left = tla.enumSet(tla.int(1), tla.int(2)) - val right = tla.enumSet(tla.int(1), tla.int(2), tla.int(3)) - val ex = tla.subset(left, right) - val state = new SymbState(ex, arena, Binding()) - val rewriter = create() - val nextState = rewriter.rewriteUntilDone(state) - nextState.ex match { - case predEx @ NameEx(name) => - rewriter.push() - solverContext.assertGroundExpr(predEx) - assert(solverContext.sat()) - rewriter.pop() - rewriter.push() - solverContext.assertGroundExpr(tla.not(predEx)) - assertUnsatOrExplain(rewriter, nextState) - - case _ => - fail("Unexpected rewriting result") - } - } - - // rewritten by Keramelizer - ignore("""SE-SUBSET[1-3]: {1, 2, 3} \subset {1, 2, 3} ~~> $B$... (false)""") { - val right = tla.enumSet(tla.int(1), tla.int(2), tla.int(3)) - val ex = tla.subset(right, right) - val state = new SymbState(ex, arena, Binding()) - val rewriter = create() - val nextState = rewriter.rewriteUntilDone(state) - nextState.ex match { - case predEx @ NameEx(name) => - rewriter.push() - solverContext.assertGroundExpr(predEx) - assertUnsatOrExplain(rewriter, nextState) - rewriter.pop() - rewriter.push() - solverContext.assertGroundExpr(tla.not(predEx)) - assert(solverContext.sat()) - - case _ => - fail("Unexpected rewriting result") - } - } - - // rewritten by Keramelizer - ignore("""SE-SUBSET[1-3]: {} \subset {1, 2, 3} ~~> TRUE""") { - val right = tla.enumSet(tla.int(1), tla.int(2), tla.int(3)) + test("""{} \subseteq {1, 2, 3} ~~> (true)""") { + val right = enumSet(int(1), int(2), int(3)) + .typed(types, "I") // an empty set requires a type annotation - val ex = tla.subset(emptySetWithType(IntT()), right) + val ex = subseteq(enumSet() ? "I", right) + .typed(types, "b") val state = new SymbState(ex, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) @@ -1107,7 +976,7 @@ class TestSymbStateRewriterSet extends RewriterBase with TestingPredefs { assert(solverContext.sat()) rewriter.pop() rewriter.push() - solverContext.assertGroundExpr(tla.not(predEx)) + solverContext.assertGroundExpr(not(predEx ? "b").typed(types, "b")) assertUnsatOrExplain(rewriter, nextState) case _ => @@ -1115,126 +984,24 @@ class TestSymbStateRewriterSet extends RewriterBase with TestingPredefs { } } - // rewritten by Keramelizer - ignore("""SE-SUBSET[1-3]: {1, 4} \subset {1, 2, 3} ~~> FALSE""") { - val left = tla.enumSet(tla.int(1), tla.int(4)) - val right = tla.enumSet(tla.int(1), tla.int(2), tla.int(3)) - val ex = tla.subset(left, right) + test("""{1, 4} \subseteq {1, 2, 3} ~~> (false)""") { + val left = enumSet(int(1), int(4)) + .typed(types, "I") + val right = enumSet(int(1), int(2), int(3)) + .typed(types, "I") + val ex = subseteq(left, right) + .typed(types, "b") val state = new SymbState(ex, arena, Binding()) val rewriter = create() val nextState = rewriter.rewriteUntilDone(state) nextState.ex match { - case predEx @ NameEx(name) => - rewriter.push() - solverContext.assertGroundExpr(predEx) - assertUnsatOrExplain(rewriter, nextState) - rewriter.pop() - rewriter.push() - solverContext.assertGroundExpr(tla.not(predEx)) - assert(solverContext.sat()) - - case _ => - fail("Unexpected rewriting result") - } - } - - // rewritten by Keramelizer - ignore("""type inference error: {1, 3} \subset {{1}, {2}, {3}}""") { - // this test worked in the past but now it reports a type inference error - val left = tla.enumSet(tla.int(1), tla.int(3)) - val right = tla.enumSet(tla.enumSet(tla.int(1)), tla.enumSet(tla.int(2)), tla.enumSet(tla.int(3))) - val ex = tla.subset(left, right) - val state = new SymbState(ex, arena, Binding()) - val rewriter = create() - assertThrows[TypeInferenceError] { - rewriter.rewriteUntilDone(state) - } - } - - // rewritten by Keramelizer - ignore("""SE-SUPSET[1-3]: {1, 2, 3} \supset {1, 2} ~~> $B$... (true)""") { - val left = tla.enumSet(tla.int(1), tla.int(2)) - val right = tla.enumSet(tla.int(1), tla.int(2), tla.int(3)) - val ex = tla.supset(right, left) - val state = new SymbState(ex, arena, Binding()) - val rewriter = create() - val nextState = rewriter.rewriteUntilDone(state) - nextState.ex match { - case predEx @ NameEx(name) => - rewriter.push() - solverContext.assertGroundExpr(predEx) - assert(solverContext.sat()) - rewriter.pop() - rewriter.push() - solverContext.assertGroundExpr(tla.not(predEx)) - assertUnsatOrExplain(rewriter, nextState) - - case _ => - fail("Unexpected rewriting result") - } - } - - // rewritten by Keramelizer - ignore("""SE-SUPSET[1-3]: {1, 2, 3} \supset {1, 2, 3} ~~> $B$... (false)""") { - val right = tla.enumSet(tla.int(1), tla.int(2), tla.int(3)) - val ex = tla.supset(right, right) - val state = new SymbState(ex, arena, Binding()) - val rewriter = create() - val nextState = rewriter.rewriteUntilDone(state) - nextState.ex match { - case predEx @ NameEx(name) => - rewriter.push() - solverContext.assertGroundExpr(predEx) - assertUnsatOrExplain(rewriter, nextState) - rewriter.pop() - rewriter.push() - solverContext.assertGroundExpr(tla.not(predEx)) - assert(solverContext.sat()) - - case _ => - fail("Unexpected rewriting result") - } - } - - // rewritten by Keramelizer - ignore("""SE-SUPSET[1-3]: {1, 2, 3} \supset {} ~~> TRUE""") { - val right = tla.enumSet(tla.int(1), tla.int(2), tla.int(3)) - // an empty set requires a type annotation - val ex = tla.supset(right, emptySetWithType(IntT())) - val state = new SymbState(ex, arena, Binding()) - val rewriter = create() - val nextState = rewriter.rewriteUntilDone(state) - nextState.ex match { - case predEx @ NameEx(name) => - rewriter.push() - solverContext.assertGroundExpr(predEx) - assert(solverContext.sat()) - rewriter.pop() - rewriter.push() - solverContext.assertGroundExpr(tla.not(predEx)) - assertUnsatOrExplain(rewriter, nextState) - - case _ => - fail("Unexpected rewriting result") - } - } - - // rewritten by Keramelizer - ignore("""SE-SUBSET[1-3]: {1, 2, 3} \subset {1, 4} ~~> FALSE""") { - val left = tla.enumSet(tla.int(1), tla.int(4)) - val right = tla.enumSet(tla.int(1), tla.int(2), tla.int(3)) - val ex = tla.subset(right, left) - val state = new SymbState(ex, arena, Binding()) - val rewriter = create() - val nextState = rewriter.rewriteUntilDone(state) - nextState.ex match { - case predEx @ NameEx(name) => + case predEx @ NameEx(_) => rewriter.push() solverContext.assertGroundExpr(predEx) assertUnsatOrExplain(rewriter, nextState) rewriter.pop() rewriter.push() - solverContext.assertGroundExpr(tla.not(predEx)) + solverContext.assertGroundExpr(not(predEx ? "b").typed(types, "b")) assert(solverContext.sat()) case _ => @@ -1242,13 +1009,17 @@ class TestSymbStateRewriterSet extends RewriterBase with TestingPredefs { } } - test("""SE-UNION: UNION {{1, 2}, {2,3}} = {1, 2, 3}""") { - val setOf12 = tla.enumSet(tla.int(1), tla.int(2)) - val setOf23 = tla.enumSet(tla.int(3), tla.int(2)) - val setOf123 = tla.enumSet(tla.int(1), tla.int(2), tla.int(3)) + test("""UNION {{1, 2}, {2,3}} = {1, 2, 3}""") { + val setOf12 = enumSet(int(1), int(2)) + .typed(types, "I") + val setOf23 = enumSet(int(3), int(2)) + .typed(types, "I") + val setOf123 = enumSet(int(1), int(2), int(3)) + .typed(types, "I") // This may seem weird, but since we don't know the type of {}, // it should be equal to the empty set of ints. - val eq = OperEx(TlaOper.eq, tla.union(tla.enumSet(setOf12, setOf23)), setOf123) + val eq = eql(union(enumSet(setOf12, setOf23) ? "II") ? "I", setOf123) + .typed(types, "b") val rewriter = create() val state = new SymbState(eq, arena, Binding()) assertTlaExAndRestore(rewriter, state) diff --git a/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestSymbStateRewriterStr.scala b/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestSymbStateRewriterStr.scala index d1d158adbb..b792546a9c 100644 --- a/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestSymbStateRewriterStr.scala +++ b/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestSymbStateRewriterStr.scala @@ -1,32 +1,20 @@ package at.forsyte.apalache.tla.bmcmt -import at.forsyte.apalache.tla.lir.convenience.tla -import at.forsyte.apalache.tla.lir.values.TlaStr -import at.forsyte.apalache.tla.lir.{NameEx, ValEx} -import at.forsyte.apalache.tla.lir.UntypedPredefs._ +import at.forsyte.apalache.tla.lir.TypedPredefs._ +import at.forsyte.apalache.tla.lir.convenience.tla._ +import at.forsyte.apalache.tla.lir.{BoolT1, StrT1} import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner @RunWith(classOf[JUnitRunner]) class TestSymbStateRewriterStr extends RewriterBase { - test("SE-STR-CTOR: \"red\" -> $C$k") { - val state = new SymbState(ValEx(TlaStr("red")), arena, Binding()) - val rewriter = create() - val nextStateRed = rewriter.rewriteUntilDone(state) - nextStateRed.ex match { - case predEx @ NameEx(name) => - assert(solverContext.sat()) - val redEqBlue = tla.eql(tla.str("blue"), tla.str("red")) - val nextStateEq = rewriter.rewriteUntilDone(nextStateRed.setRex(redEqBlue)) - rewriter.push() - solverContext.assertGroundExpr(nextStateEq.ex) - assert(!solverContext.sat()) - rewriter.pop() - solverContext.assertGroundExpr(tla.not(nextStateEq.ex)) - assert(solverContext.sat()) + test(""" rewrite "red" """) { + val string = str("red").typed(StrT1()) + val neq = not(eql(str("red"), str("blue")).typed(BoolT1())) + .typed(BoolT1()) - case _ => - fail("Unexpected rewriting result") - } + val state = new SymbState(neq, arena, Binding()) + val rewriter = create() + assertTlaExAndRestore(rewriter, state) } } diff --git a/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestSymbStateRewriterTlc.scala b/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestSymbStateRewriterTlc.scala index 8d4373e416..3e94df4a42 100644 --- a/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestSymbStateRewriterTlc.scala +++ b/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestSymbStateRewriterTlc.scala @@ -1,27 +1,24 @@ package at.forsyte.apalache.tla.bmcmt -import at.forsyte.apalache.tla.bmcmt.types.{BoolT, FailPredT} -import at.forsyte.apalache.tla.lir.convenience.tla +import at.forsyte.apalache.tla.lir.TypedPredefs._ +import at.forsyte.apalache.tla.lir.convenience.tla._ import at.forsyte.apalache.tla.lir.oper.TlcOper -import at.forsyte.apalache.tla.lir.{NameEx, OperEx} -import at.forsyte.apalache.tla.lir.UntypedPredefs._ +import at.forsyte.apalache.tla.lir._ import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner @RunWith(classOf[JUnitRunner]) class TestSymbStateRewriterTlc extends RewriterBase { test("SE-TLC-PRINT: PRINT(...) -> TRUE") { - val print = OperEx(TlcOper.print, tla.int(1), tla.str("hello")) + // Builder does not have a standard method for TLC!Print, as we do not construct internally + val print = OperEx(TlcOper.print, int(1).typed(), str("hello").typed())(Typed(StrT1())) val state = new SymbState(print, arena, Binding()) val rewriter = create() val nextStateRed = rewriter.rewriteUntilDone(state) nextStateRed.ex match { - case predEx @ NameEx(name) => + case NameEx(_) => solverContext.assertGroundExpr(nextStateRed.ex) assert(solverContext.sat()) - val failPreds = state.arena.findCellsByType(FailPredT()) - solverContext.assertGroundExpr(tla.or(failPreds.map(_.toNameEx): _*)) - assert(!solverContext.sat()) // no failures should be registered case _ => fail("Unexpected rewriting result") @@ -29,17 +26,15 @@ class TestSymbStateRewriterTlc extends RewriterBase { } test("SE-TLC-PRINT: PRINTT(...) -> TRUE") { - val print = OperEx(TlcOper.printT, tla.str("hello")) + // Builder does not have a standard method for TLC!PrintT, as we do not construct internally + val print = OperEx(TlcOper.printT, str("hello").typed())(Typed(StrT1())) val state = new SymbState(print, arena, Binding()) val rewriter = create() val nextStateRed = rewriter.rewriteUntilDone(state) nextStateRed.ex match { - case predEx @ NameEx(name) => + case NameEx(_) => solverContext.assertGroundExpr(nextStateRed.ex) assert(solverContext.sat()) - val failPreds = state.arena.findCellsByType(FailPredT()) - solverContext.assertGroundExpr(tla.or(failPreds.map(_.toNameEx): _*)) - assert(!solverContext.sat()) // no failures should be registered case _ => fail("Unexpected rewriting result") @@ -47,17 +42,15 @@ class TestSymbStateRewriterTlc extends RewriterBase { } test("SE-TLC-ASSERT: Assert(TRUE, _) -> reach") { - val assertEx = OperEx(TlcOper.assert, tla.bool(true), tla.str("oops")) + // Builder does not have a standard method for TLC!Assert, as we do not construct internally + val assertEx = OperEx(TlcOper.assert, bool(true).typed(), str("oops").typed())(Typed(BoolT1())) val state = new SymbState(assertEx, arena, Binding()) val rewriter = create() val nextStateRed = rewriter.rewriteUntilDone(state) nextStateRed.ex match { - case predEx @ NameEx(name) => + case NameEx(_) => solverContext.assertGroundExpr(nextStateRed.ex) assert(solverContext.sat()) - val failPreds = nextStateRed.arena.findCellsByType(FailPredT()) - solverContext.assertGroundExpr(tla.or(failPreds.map(_.toNameEx): _*)) - assert(!solverContext.sat()) // no failures should be registered case _ => fail("Unexpected rewriting result") @@ -65,76 +58,15 @@ class TestSymbStateRewriterTlc extends RewriterBase { } test("SE-TLC-ASSERT: Assert(FALSE, _) -> TRUE") { - val assertEx = OperEx(TlcOper.assert, tla.bool(false), tla.str("oops")) + // Builder does not have a standard method for TLC!Assert, as we do not construct internally + val assertEx = OperEx(TlcOper.assert, bool(false).typed(), str("oops").typed())(Typed(BoolT1())) val state = new SymbState(assertEx, arena, Binding()) val rewriter = create() val nextStateRed = rewriter.rewriteUntilDone(state) nextStateRed.ex match { - case predEx @ NameEx(name) => + case NameEx(_) => solverContext.assertGroundExpr(nextStateRed.ex) assert(solverContext.sat()) - val failPreds = nextStateRed.arena.findCellsByType(FailPredT()) - assert(failPreds.length == 1) - solverContext.assertGroundExpr(tla.or(failPreds.map(_.toNameEx): _*)) - assert(solverContext.sat()) // a failure should be registered - solverContext.evalGroundExpr(failPreds.head.toNameEx) == tla.bool(true) - val message = rewriter.findMessage(failPreds.head.id) - assert(message == "Assertion error: oops") - - case _ => - fail("Unexpected rewriting result") - } - } - - // the failure predicates should be refactored - ignore("SE-TLC-ASSERT: IF FALSE THEN Assert(FALSE, _) ELSE TRUE -> TRUE") { - val assertEx = tla.ite(tla.bool(false), OperEx(TlcOper.assert, tla.bool(false), tla.str("oops")), tla.bool(true)) - val state = new SymbState(assertEx, arena, Binding()) - val rewriter = create() - val nextStateRed = rewriter.rewriteUntilDone(state) - nextStateRed.ex match { - case predEx @ NameEx(name) => - solverContext.assertGroundExpr(nextStateRed.ex) - assert(solverContext.sat()) - val failPreds = nextStateRed.arena.findCellsByType(FailPredT()) - assert(failPreds.length == 1) - solverContext.assertGroundExpr(tla.or(failPreds.map(_.toNameEx): _*)) - assert(!solverContext.sat()) // a failure should not be registered - - case _ => - fail("Unexpected rewriting result") - } - } - - // the failure predicates should be refactored - ignore("SE-TLC-ASSERT: x \\/ Assert(FALSE, _) -> depends on x") { - // somewhat surprising, the expected behavior of TLC is to short-circuit the evaluation, - // see Specifying Systems, Sec. 14.2.2, p. 231. - arena = arena.appendCell(BoolT()) - val x = arena.topCell // we use a variable to avoid constant optimizations - val assertEx = tla.or(x.toNameEx, OperEx(TlcOper.assert, tla.bool(false), tla.str("oops"))) - val rewriter = create() - val state = new SymbState(assertEx, arena, Binding()) - val nextStateRed = rewriter.rewriteUntilDone(state) - nextStateRed.ex match { - case predEx @ NameEx(name) => - solverContext.assertGroundExpr(nextStateRed.ex) - val failPreds = nextStateRed.arena.findCellsByType(FailPredT()) - assert(failPreds.length == 1) - rewriter.push() - // x = TRUE => no assertion failure - solverContext.assertGroundExpr(x.toNameEx) // x = TRUE - assert(solverContext.sat()) - solverContext.assertGroundExpr(tla.or(failPreds.map(_.toNameEx): _*)) - assert(!solverContext.sat(), "no assertion failure expected") - rewriter.pop() - rewriter.push() - // x = FALSE => assertion failure - solverContext.assertGroundExpr(tla.not(x.toNameEx)) // x = FALSE - assert(solverContext.sat()) - solverContext.assertGroundExpr(tla.or(failPreds.map(_.toNameEx): _*)) - assert(solverContext.sat(), "assertion failure expected") - rewriter.pop() case _ => fail("Unexpected rewriting result") diff --git a/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestSymbStateRewriterTuple.scala b/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestSymbStateRewriterTuple.scala index f1a27c6b99..0bf433c681 100644 --- a/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestSymbStateRewriterTuple.scala +++ b/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestSymbStateRewriterTuple.scala @@ -1,154 +1,97 @@ package at.forsyte.apalache.tla.bmcmt import at.forsyte.apalache.tla.bmcmt.types._ -import at.forsyte.apalache.tla.lir.NameEx -import at.forsyte.apalache.tla.lir.convenience.tla -import at.forsyte.apalache.tla.lir.oper.TlaFunOper -import at.forsyte.apalache.tla.lir.UntypedPredefs._ +import at.forsyte.apalache.tla.lir.{BoolT1, IntT1, NameEx, SetT1, StrT1, TupT1} +import at.forsyte.apalache.tla.lir.convenience.tla._ +import at.forsyte.apalache.tla.lir.TypedPredefs._ import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner @RunWith(classOf[JUnitRunner]) class TestSymbStateRewriterTuple extends RewriterBase { - test("""SE-TUPLE-CTOR[1-2]: <<1, FALSE, {2}>> ~~> $C$k""") { - val tuple = tla.tuple(tla.int(1), tla.bool(false), tla.enumSet(tla.int(2))) - - val state = new SymbState(tuple, arena, Binding()) - val nextState = create().rewriteUntilDone(state) - nextState.ex match { - case membershipEx @ NameEx(name) => - assert(solverContext.sat()) - val cell = nextState.arena.findCellByName(name) - assert(TupleT(List(IntT(), BoolT(), FinSetT(IntT()))) == cell.cellType) - - case _ => - fail("Unexpected rewriting result") - } + private val types = Map( + "b" -> BoolT1(), + "i" -> IntT1(), + "(i)" -> TupT1(IntT1()), + "I" -> SetT1(IntT1()), + "ib" -> TupT1(IntT1(), BoolT1()), + "ibs" -> TupT1(IntT1(), BoolT1(), StrT1()), + "IB" -> SetT1(TupT1(IntT1(), BoolT1())), + "ibI" -> TupT1(IntT1(), BoolT1(), SetT1(IntT1())) + ) + + test("""<<1, FALSE, {2}>>""") { + val tup = tuple(int(1), bool(false), enumSet(int(2)) ? "I") + .typed(types, "ibI") + + val state = new SymbState(tup, arena, Binding()) + val _ = create().rewriteUntilDone(state) + assert(solverContext.sat()) } - test("""SE-TPL-ACC[1-2]: <<1, FALSE, {2}>>[2] ~~> $C$k equals FALSE""") { - val tuple = tla.tuple(tla.int(1), tla.bool(false), tla.enumSet(tla.int(2))) - val tupleAcc = tla.appFun(tuple, tla.int(2)) - val state = new SymbState(tupleAcc, arena, Binding()) - val nextState = create().rewriteUntilDone(state) - nextState.ex match { - case membershipEx @ NameEx(name) => - assert(solverContext.sat()) - val cell = nextState.arena.findCellByName(name) - cell.cellType match { - case BoolT() => - assert(solverContext.sat()) - solverContext.push() - solverContext.assertGroundExpr(tla.eql(cell.toNameEx, tla.bool(false))) - assert(solverContext.sat()) - solverContext.pop() - solverContext.assertGroundExpr(tla.eql(cell.toNameEx, tla.bool(true))) - assert(!solverContext.sat()) - - // we check the actual contents in the later tests that access elements + test(""" <<1, FALSE, {2}>>[2] returns FALSE""") { + val tup = tuple(int(1), bool(false), enumSet(int(2)) ? "I") + val tupleAcc = appFun(tup ? "ibI", int(2)) + .typed(types, "b") + val resEqFalse = eql(tupleAcc, bool(false)) + .typed(BoolT1()) - case _ => - fail("Expected Boolean type") - } - - case _ => - fail("Unexpected rewriting result") - } + val state = new SymbState(resEqFalse, arena, Binding()) + assertTlaExAndRestore(create(), state) } - test("""SE-TUPLE-CTOR[1-2] in a set: {<<1, FALSE>>, <<2, TRUE>>} ~~> $C$k""") { - val tuple1 = tla.tuple(tla.int(1), tla.bool(false)) - val tuple2 = tla.tuple(tla.int(2), tla.bool(true)) + test("""{<<1, FALSE>>, <<2, TRUE>>} works""") { + val tuple1 = tuple(int(1), bool(false)) + val tuple2 = tuple(int(2), bool(true)) + val tupleSet = enumSet(tuple1 ? "ib", tuple2 ? "ib") + .typed(types, "IB") - val state = new SymbState(tla.enumSet(tuple1, tuple2), arena, Binding()) + val state = new SymbState(tupleSet, arena, Binding()) val nextState = create().rewriteUntilDone(state) - nextState.ex match { - case membershipEx @ NameEx(name) => - assert(solverContext.sat()) - val cell = nextState.arena.findCellByName(name) - assert(FinSetT(TupleT(List(IntT(), BoolT()))) == cell.cellType) - - case _ => - fail("Unexpected rewriting result") - } - } - - test("""type inference error: {<<1, FALSE>>, <<2>>}""") { - val tuple1 = tla.tuple(tla.int(1), tla.bool(false)) - val tuple2 = tla.tuple(tla.int(2)) - - val state = new SymbState(tla.enumSet(tuple1, tuple2), arena, Binding()) - assertThrows[TypeInferenceException] { - create().rewriteUntilDone(state) - fail("Expected a type error") - } - } - - test("""type inference error: {<<1, FALSE>>, <>} ~~> $C$k""") { - val tuple1 = tla.tuple(tla.int(1), tla.bool(false)) - val tuple2 = tla.tuple(tla.bool(true), tla.int(2)) - - val state = new SymbState(tla.enumSet(tuple1, tuple2), arena, Binding()) - assertThrows[TypeInferenceException] { - create().rewriteUntilDone(state) - } + assert(solverContext.sat()) } - test("""SE-TUPLE-EQ: ~(<<2, FALSE>> = <<2, TRUE>>) ~~> $C$k""") { - val tuple1 = tla.tuple(tla.int(2), tla.bool(false)) - val tuple2 = tla.tuple(tla.int(2), tla.bool(true)) - val eq = tla.not(tla.eql(tuple1, tuple2)) + test("""~(<<2, FALSE>> = <<2, TRUE>>)""") { + val tuple1 = tuple(int(2), bool(false)) + val tuple2 = tuple(int(2), bool(true)) + val eq = not(eql(tuple1 ? "ib", tuple2 ? "ib") ? "b") + .typed(types, "b") val rewriter = create() val state = new SymbState(eq, arena, Binding()) assertTlaExAndRestore(rewriter, state) } - test("""SE-TUPLE-EQ: <<2, FALSE>> = <<2, FALSE>> ~~> $C$k""") { - val tuple1 = tla.tuple(tla.int(2), tla.bool(false)) - val tuple2 = tla.tuple(tla.int(2), tla.bool(false)) - val eq = tla.eql(tuple1, tuple2) + test("""<<2, FALSE>> = <<2, FALSE>>""") { + val tuple1 = tuple(int(2), bool(false)) + val tuple2 = tuple(int(2), bool(false)) + val eq = eql(tuple1 ? "ib", tuple2 ? "ib") + .typed(types, "b") val rewriter = create() val state = new SymbState(eq, arena, Binding()) assertTlaExAndRestore(rewriter, state) } - // Keramelizer rewrites \X - ignore("""SE-TUPLE-SET: {<<1, FALSE>>, <<2, FALSE>>, <<1, TRUE>>, <<2, TRUE>> = {1,2} \X {FALSE, TRUE} ~~> $B$k""") { - val set12 = tla.enumSet(1 to 2 map tla.int: _*) - val setBool = tla.enumSet(tla.bool(false), tla.bool(true)) - val prod = tla.times(set12, setBool) - def tup(i: Int, b: Boolean) = tla.tuple(tla.int(i), tla.bool(b)) - val eq = tla.eql(prod, tla.enumSet(tup(1, false), tup(1, true), tup(2, false), tup(2, true))) - - val state = new SymbState(eq, arena, Binding()) - val rewriter = create() - val nextState = rewriter.rewriteUntilDone(state) - rewriter.push() - solverContext.assertGroundExpr(nextState.ex) - assert(solverContext.sat()) - rewriter.pop() - solverContext.assertGroundExpr(tla.not(nextState.ex)) - assert(!solverContext.sat()) - } - - test("""SE-TUPLE-DOM: DOMAIN <<2, FALSE>> = {1, 2}""") { - val tuple = tla.tuple(tla.int(2), tla.bool(false), tla.str("c")) - val set123 = tla.enumSet(1.to(3) map tla.int: _*) - val eq = tla.eql(tla.dom(tuple), set123) + test("""DOMAIN <<2, FALSE, "c">> = {1, 2, 3}""") { + val tup = tuple(int(2), bool(false), str("c")) + val set123 = enumSet(1.to(3) map int: _*) + val eq = eql(dom(tup ? "ibs") ? "I", set123 ? "I") + .typed(types, "b") val state = new SymbState(eq, arena, Binding()) val rewriter = create() assertTlaExAndRestore(rewriter, state) } - test("""SE-TUPLE-EXCEPT: [ <<1, FALSE>> EXCEPT ![1] = 3 ]""") { - val tuple = tla.tuple(tla.int(1), tla.bool(false)) - val except = tla.except(tuple, tla.tuple(tla.int(1)), tla.int(3)) - val state = new SymbState(except, arena, Binding()) + test("""[ <<1, FALSE>> EXCEPT ![1] = 3 ]""") { + val tup = tuple(int(1), bool(false)) + val newTuple = except(tup ? "ib", tuple(int(1)) ? "(i)", int(3)) + .typed(types, "ib") + val eq = eql(newTuple, tuple(int(3), bool(false)) ? "ib") + .typed(types, "b") + + val state = new SymbState(eq, arena, Binding()) val rewriter = create() - val nextState = rewriter.rewriteUntilDone(state) - val expectedTuple = tla.tuple(tla.int(3), tla.bool(false)) - assertTlaExAndRestore(rewriter, nextState.setRex(tla.eql(expectedTuple, nextState.ex))) + assertTlaExAndRestore(rewriter, state.setRex(eq)) } } diff --git a/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestTrivialTypeFinder.scala b/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestTrivialTypeFinder.scala deleted file mode 100644 index 08ae906e80..0000000000 --- a/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestTrivialTypeFinder.scala +++ /dev/null @@ -1,884 +0,0 @@ -package at.forsyte.apalache.tla.bmcmt - -import at.forsyte.apalache.tla.bmcmt.types._ -import at.forsyte.apalache.tla.bmcmt.types.eager.TrivialTypeFinder -import at.forsyte.apalache.tla.lir.convenience.tla -import at.forsyte.apalache.tla.lir.oper.TlaOper -import at.forsyte.apalache.tla.lir.values.{TlaIntSet, TlaStrSet} -import at.forsyte.apalache.tla.lir.transformations.standard.IncrementalRenaming -import at.forsyte.apalache.tla.lir._ -import at.forsyte.apalache.tla.lir.UntypedPredefs._ -import org.junit.runner.RunWith -import org.scalatest.junit.JUnitRunner - -import scala.collection.immutable.SortedMap - -// TODO: add tests for TlaSetOper.seqSet - -@RunWith(classOf[JUnitRunner]) -class TestTrivialTypeFinder extends RewriterBase { - test("compute IntT") { - val typeFinder = new TrivialTypeFinder() - val cellType = typeFinder.compute(tla.int(1)) - assert(IntT() == cellType) - } - - test("compute BoolT") { - val typeFinder = new TrivialTypeFinder() - val cellType = typeFinder.compute(tla.bool(false)) - assert(BoolT() == cellType) - } - - test("compute ConstT") { - val typeFinder = new TrivialTypeFinder() - val cellType = typeFinder.compute(tla.str("hello")) - assert(ConstT() == cellType) - } - - test("compute names") { - val typeFinder = new TrivialTypeFinder() - val x = tla.name("x") - typeFinder.reset(Map("x" -> BoolT())) - assert(BoolT() == typeFinder.compute(x)) - typeFinder.reset(Map.empty) - typeFinder.compute(x) - assert(typeFinder.typeErrors.nonEmpty) - } - - test("compute basic operators") { - val typeFinder = new TrivialTypeFinder() - val x = tla.name("x") - val p = tla.name("p") - val set1 = tla.name("S") - val set2 = tla.name("T") - // good cases - assert( - BoolT() == - typeFinder.compute(tla.eql(set1, set2), FinSetT(IntT()), FinSetT(IntT()))) - assert( - BoolT() == - typeFinder.compute(tla.neql(set1, set2), FinSetT(IntT()), FinSetT(IntT()))) - assert( - IntT() == - typeFinder.compute(tla.choose(x, set1, p), IntT(), FinSetT(IntT()), BoolT())) - assert( - IntT() == - typeFinder.compute(tla.choose(x, p), IntT(), BoolT())) - assert( - IntT() == - typeFinder.compute(OperEx(TlaOper.chooseIdiom, set1), FinSetT(IntT()))) - assert( - FinSetT(IntT()) - == typeFinder.compute(tla.label(set1, "lab", "a"), FinSetT(IntT()), ConstT(), ConstT())) - // bad cases - typeFinder.compute(tla.eql(set1, set2), FinSetT(IntT()), FinSetT(BoolT())) - assert(typeFinder.typeErrors.nonEmpty) - - typeFinder.compute(tla.neql(set1, set2), FinSetT(IntT()), FinSetT(BoolT())) - assert(typeFinder.typeErrors.nonEmpty) - assert( - IntT() == - typeFinder.compute(tla.choose(x, set1, p), BoolT(), FinSetT(IntT()), BoolT())) - assert(typeFinder.typeErrors.nonEmpty) - typeFinder.compute(tla.choose(x, set1, p), IntT(), IntT(), BoolT()) - assert(typeFinder.typeErrors.nonEmpty) - assert( - IntT() == - typeFinder.compute(tla.choose(x, set1, p), IntT(), FinSetT(IntT()), IntT())) - assert(typeFinder.typeErrors.nonEmpty) - assert(IntT() == typeFinder.compute(tla.choose(x, p), IntT(), IntT())) - assert(typeFinder.typeErrors.nonEmpty) - typeFinder.compute(OperEx(TlaOper.chooseIdiom, set1), IntT()) - assert(typeFinder.typeErrors.nonEmpty) - - typeFinder.compute(tla.label(set1, "lab", "a"), FinSetT(IntT()), IntT(), ConstT()) - assert(typeFinder.typeErrors.nonEmpty) - } - - test("compute int to int operators") { - val typeFinder = new TrivialTypeFinder() - val i = tla.name("i") - val j = tla.name("j") - // good cases - assert(IntT() == typeFinder.compute(tla.uminus(i), IntT())) - assert(IntT() == typeFinder.compute(tla.plus(i, j), IntT(), IntT())) - assert(IntT() == typeFinder.compute(tla.minus(i, j), IntT(), IntT())) - assert(IntT() == typeFinder.compute(tla.mult(i, j), IntT(), IntT())) - assert(IntT() == typeFinder.compute(tla.div(i, j), IntT(), IntT())) - assert(IntT() == typeFinder.compute(tla.mod(i, j), IntT(), IntT())) - assert(IntT() == typeFinder.compute(tla.exp(i, j), IntT(), IntT())) - assert(IntT() == typeFinder.compute(tla.sum(i, i, j), IntT(), IntT(), IntT())) - assert(IntT() == typeFinder.compute(tla.prod(i, i, j), IntT(), IntT(), IntT())) - // bad cases - typeFinder.compute(tla.uminus(i), BoolT()) - assert(typeFinder.typeErrors.nonEmpty) - typeFinder.compute(tla.plus(i, j), BoolT(), IntT()) - assert(typeFinder.typeErrors.nonEmpty) - typeFinder.compute(tla.minus(i, j), BoolT(), IntT()) - assert(typeFinder.typeErrors.nonEmpty) - typeFinder.compute(tla.mult(i, j), BoolT(), IntT()) - assert(typeFinder.typeErrors.nonEmpty) - typeFinder.compute(tla.div(i, j), BoolT(), IntT()) - assert(typeFinder.typeErrors.nonEmpty) - typeFinder.compute(tla.mod(i, j), BoolT(), IntT()) - assert(typeFinder.typeErrors.nonEmpty) - typeFinder.compute(tla.exp(i, j), BoolT(), IntT()) - assert(typeFinder.typeErrors.nonEmpty) - typeFinder.compute(tla.sum(i, i, j), IntT(), BoolT(), IntT()) - assert(typeFinder.typeErrors.nonEmpty) - } - - test("compute int to bool operators") { - val typeFinder = new TrivialTypeFinder() - val i = tla.name("i") - val j = tla.name("j") - // good cases - assert(BoolT() == typeFinder.compute(tla.lt(i, j), IntT(), IntT())) - assert(BoolT() == typeFinder.compute(tla.gt(i, j), IntT(), IntT())) - assert(BoolT() == typeFinder.compute(tla.le(i, j), IntT(), IntT())) - assert(BoolT() == typeFinder.compute(tla.ge(i, j), IntT(), IntT())) - // bad cases - typeFinder.compute(tla.lt(i, j), IntT(), BoolT()) - assert(typeFinder.typeErrors.nonEmpty) - typeFinder.compute(tla.gt(i, j), IntT(), BoolT()) - assert(typeFinder.typeErrors.nonEmpty) - typeFinder.compute(tla.le(i, j), IntT(), BoolT()) - assert(typeFinder.typeErrors.nonEmpty) - typeFinder.compute(tla.ge(i, j), IntT(), BoolT()) - assert(typeFinder.typeErrors.nonEmpty) - } - - test("compute int to X operators") { - val typeFinder = new TrivialTypeFinder() - val i = tla.name("i") - val j = tla.name("j") - // good cases - assert(FinSetT(IntT()) == typeFinder.compute(tla.dotdot(i, j), IntT(), IntT())) - // bad cases - typeFinder.compute(tla.dotdot(i, j), IntT(), BoolT()) - assert(typeFinder.typeErrors.nonEmpty) - } - - test("compute bool operators") { - val typeFinder = new TrivialTypeFinder() - val x = tla.name("x") - val p = tla.name("p") - val q = tla.name("q") - val S = tla.name("S") - // good cases - assert(BoolT() == typeFinder.compute(tla.not(p), BoolT())) - assert(BoolT() == typeFinder.compute(tla.and(p, q), BoolT(), BoolT())) - assert(BoolT() == typeFinder.compute(tla.or(p, q), BoolT(), BoolT())) - assert(BoolT() == typeFinder.compute(tla.impl(p, q), BoolT(), BoolT())) - assert(BoolT() == typeFinder.compute(tla.equiv(p, q), BoolT(), BoolT())) - assert(BoolT() == typeFinder.compute(tla.forall(x, S, p), IntT(), FinSetT(IntT()), BoolT())) - assert(BoolT() == typeFinder.compute(tla.exists(x, S, p), IntT(), FinSetT(IntT()), BoolT())) - // bad cases - typeFinder.compute(tla.not(p), IntT()) - assert(typeFinder.typeErrors.nonEmpty) - typeFinder.compute(tla.and(p, q), IntT(), BoolT()) - assert(typeFinder.typeErrors.nonEmpty) - typeFinder.compute(tla.or(p, q), IntT(), BoolT()) - assert(typeFinder.typeErrors.nonEmpty) - typeFinder.compute(tla.impl(p, q), IntT(), BoolT()) - assert(typeFinder.typeErrors.nonEmpty) - typeFinder.compute(tla.equiv(p, q), IntT(), BoolT()) - assert(typeFinder.typeErrors.nonEmpty) - typeFinder.compute(tla.forall(x, S, p), IntT(), FinSetT(ConstT()), BoolT()) - assert(typeFinder.typeErrors.nonEmpty) - typeFinder.compute(tla.exists(x, S, p), IntT(), FinSetT(ConstT()), BoolT()) - assert(typeFinder.typeErrors.nonEmpty) - } - - test("compute control operators") { - val typeFinder = new TrivialTypeFinder() - val p = NameEx("p") - val q = NameEx("q") - val x = NameEx("x") - val y = NameEx("y") - val z = NameEx("z") - // good cases - assert( - FinSetT(IntT()) == - typeFinder.compute(tla.ite(p, x, y), BoolT(), FinSetT(IntT()), FinSetT(IntT()))) - val caseEx = tla.caseSplit(p, x, q, y) - assert( - FinSetT(IntT()) == - typeFinder.compute(caseEx, BoolT(), FinSetT(IntT()), BoolT(), FinSetT(IntT()))) - val caseOtherEx = tla.caseOther(z, p, x, q, y) - assert( - FinSetT(IntT()) == - typeFinder.compute(caseOtherEx, FinSetT(IntT()), BoolT(), FinSetT(IntT()), BoolT(), FinSetT(IntT()))) - - val decl = TlaOperDecl("A", List(), tla.plus(tla.int(1), tla.int(2))) - val letIn = tla.letIn(tla.plus(tla.int(1), tla.appDecl(decl)), decl) - assert(IntT() == typeFinder.compute(letIn, IntT())) - // bad cases - typeFinder.compute(tla.ite(p, x, y), IntT(), FinSetT(IntT()), FinSetT(IntT())) - assert(typeFinder.typeErrors.nonEmpty) - typeFinder.compute(tla.ite(p, x, y), BoolT(), FinSetT(BoolT()), FinSetT(IntT())) - assert(typeFinder.typeErrors.nonEmpty) - typeFinder.compute(caseEx, BoolT(), FinSetT(IntT()), IntT(), FinSetT(IntT())) - assert(typeFinder.typeErrors.nonEmpty) - typeFinder.compute(caseEx, BoolT(), FinSetT(IntT()), BoolT(), FinSetT(BoolT())) - assert(typeFinder.typeErrors.nonEmpty) - typeFinder.compute(caseOtherEx, FinSetT(BoolT()), BoolT(), FinSetT(IntT()), BoolT(), FinSetT(IntT())) - assert(typeFinder.typeErrors.nonEmpty) - typeFinder.compute(caseOtherEx, FinSetT(IntT()), BoolT(), FinSetT(IntT()), IntT(), FinSetT(IntT())) - assert(typeFinder.typeErrors.nonEmpty) - } - - test("compute Boolean set operators") { - val typeFinder = new TrivialTypeFinder() - val int1 = tla.int(1) - val set12 = tla.enumSet(tla.int(1), tla.int(2)) - val set123 = tla.enumSet(tla.int(1), tla.int(2), tla.int(3)) - // good cases - assert( - BoolT() == - typeFinder.compute(tla.in(int1, set12), IntT(), FinSetT(IntT()))) - assert( - BoolT() == - typeFinder.compute(tla.notin(int1, set12), IntT(), FinSetT(IntT()))) - assert( - BoolT() == - typeFinder.compute(tla.subset(set12, set123), FinSetT(IntT()), FinSetT(IntT()))) - assert( - BoolT() == - typeFinder.compute(tla.subseteq(set12, set123), FinSetT(IntT()), FinSetT(IntT()))) - assert( - BoolT() == - typeFinder.compute(tla.supset(set12, set123), FinSetT(IntT()), FinSetT(IntT()))) - assert( - BoolT() == - typeFinder.compute(tla.supseteq(set12, set123), FinSetT(IntT()), FinSetT(IntT()))) - // bad cases - typeFinder.compute(tla.in(set12, set12), FinSetT(IntT()), FinSetT(IntT())) - assert(typeFinder.typeErrors.nonEmpty) - typeFinder.compute(tla.notin(set12, set12), FinSetT(IntT()), FinSetT(IntT())) - assert(typeFinder.typeErrors.nonEmpty) - typeFinder.compute(tla.subset(int1, set12), IntT(), FinSetT(IntT())) - assert(typeFinder.typeErrors.nonEmpty) - typeFinder.compute(tla.subseteq(int1, set12), IntT(), FinSetT(IntT())) - assert(typeFinder.typeErrors.nonEmpty) - typeFinder.compute(tla.supset(int1, set12), IntT(), FinSetT(IntT())) - assert(typeFinder.typeErrors.nonEmpty) - typeFinder.compute(tla.supseteq(int1, set12), IntT(), FinSetT(IntT())) - assert(typeFinder.typeErrors.nonEmpty) - } - - test("compute set ctor") { - val typeFinder = new TrivialTypeFinder() - val cellType = typeFinder.compute(tla.enumSet(tla.int(1)), IntT()) - assert(FinSetT(IntT()) == cellType) - } - - test("compute set-algebraic operators") { - val typeFinder = new TrivialTypeFinder() - val int1 = tla.int(1) - val set12 = tla.enumSet(tla.int(1), tla.int(2)) - val set123 = tla.enumSet(tla.int(1), tla.int(2), tla.int(3)) - // good cases - assert( - FinSetT(IntT()) == - typeFinder.compute(tla.cup(set12, set123), FinSetT(IntT()), FinSetT(IntT()))) - assert( - FinSetT(IntT()) == - typeFinder.compute(tla.cap(set12, set123), FinSetT(IntT()), FinSetT(IntT()))) - assert( - FinSetT(IntT()) == - typeFinder.compute(tla.setminus(set12, set123), FinSetT(IntT()), FinSetT(IntT()))) - // bad cases - typeFinder.compute(tla.cup(int1, set123), IntT(), FinSetT(IntT())) - assert(typeFinder.typeErrors.nonEmpty) - typeFinder.compute(tla.cap(int1, set123), IntT(), FinSetT(IntT())) - assert(typeFinder.typeErrors.nonEmpty) - typeFinder.compute(tla.setminus(int1, set123), IntT(), FinSetT(IntT())) - assert(typeFinder.typeErrors.nonEmpty) - } - - test("compute set filter") { - val typeFinder = new TrivialTypeFinder() - val int1 = tla.int(1) - val set12 = tla.enumSet(tla.int(1), tla.int(2)) - // good cases - assert( - FinSetT(IntT()) == - typeFinder.compute(tla.filter(NameEx("x"), set12, tla.bool(true)), IntT(), FinSetT(IntT()), BoolT())) - // bad cases - typeFinder.compute(tla.filter(NameEx("x"), int1, tla.bool(true)), IntT(), IntT(), BoolT()) - assert(typeFinder.typeErrors.nonEmpty) - typeFinder.compute(tla.filter(NameEx("x"), set12, int1), IntT(), FinSetT(IntT()), IntT()) - assert(typeFinder.typeErrors.nonEmpty) - } - - test("compute set map") { - val typeFinder = new TrivialTypeFinder() - val int1 = tla.int(1) - val set12 = tla.enumSet(tla.int(1), tla.int(2)) - // good cases - assert( - FinSetT(BoolT()) == - typeFinder.compute(tla.map(tla.bool(true), NameEx("x"), set12), BoolT(), IntT(), FinSetT(IntT()))) - assert( - FinSetT(IntT()) == - typeFinder.compute(tla.map(tla.bool(true), NameEx("x"), set12), IntT(), BoolT(), FinSetT(BoolT()))) - // bad cases - typeFinder.compute(tla.map(tla.bool(true), NameEx("x"), int1), BoolT(), IntT(), IntT()) - assert(typeFinder.typeErrors.nonEmpty) - typeFinder.compute(tla.map(tla.bool(true), NameEx("x"), set12, NameEx("y"), int1), BoolT(), IntT(), FinSetT(IntT()), - IntT(), IntT()) - assert(typeFinder.typeErrors.nonEmpty) - } - - test("compute UNION") { - val typeFinder = new TrivialTypeFinder() - val int1 = tla.int(1) - val set12 = tla.enumSet(tla.int(1), tla.int(2)) - val set3 = tla.enumSet(tla.int(3)) - // good cases - assert( - FinSetT(IntT()) == - typeFinder.compute(tla.union(tla.enumSet(set12, set3)), FinSetT(FinSetT(IntT())))) - // bad cases - typeFinder.compute(tla.union(set12), FinSetT(IntT())) - assert(typeFinder.typeErrors.nonEmpty) - } - - test("compute SUBSET") { - val typeFinder = new TrivialTypeFinder() - val int1 = tla.int(1) - val set12 = tla.enumSet(tla.int(1), tla.int(2)) - val set3 = tla.enumSet(tla.int(3)) - - // good cases - assert( - FinSetT(FinSetT(IntT())) == - typeFinder.compute(tla.powSet(set12), FinSetT(IntT()))) - // bad cases - typeFinder.compute(tla.powSet(int1), IntT()) - assert(typeFinder.typeErrors.nonEmpty) - } - - test("compute \\X") { - val typeFinder = new TrivialTypeFinder() - val int1 = tla.int(1) - val set12 = tla.enumSet(tla.int(1), tla.int(2)) - val set3 = tla.enumSet(tla.bool(true)) - - // good cases - assert(FinSetT(TupleT(Seq(IntT(), BoolT()))) == - typeFinder.compute(tla.times(set12, set3), FinSetT(IntT()), FinSetT(BoolT()))) - // bad cases - typeFinder.compute(tla.times(set12, int1), FinSetT(IntT()), IntT()) - assert(typeFinder.typeErrors.nonEmpty) - } - - test("compute [S -> T]") { - val typeFinder = new TrivialTypeFinder() - val int1 = tla.int(1) - val set12 = tla.enumSet(tla.int(1), tla.int(2)) - val boolSet = tla.enumSet(tla.bool(false), tla.bool(true)) - - // good cases - assert(FinSetT(FunT(FinSetT(IntT()), BoolT())) == - typeFinder.compute(tla.funSet(set12, boolSet), FinSetT(IntT()), FinSetT(BoolT()))) - // bad cases - typeFinder.compute(tla.funSet(set12, tla.bool(false)), FinSetT(IntT()), BoolT()) - assert(typeFinder.typeErrors.nonEmpty) - typeFinder.compute(tla.funSet(int1, boolSet), IntT(), FinSetT(BoolT())) - assert(typeFinder.typeErrors.nonEmpty) - } - - test("compute [a: S, b: T]") { - val typeFinder = new TrivialTypeFinder() - val int1 = tla.int(1) - val set12 = tla.enumSet(tla.int(1), tla.int(2)) - val boolSet = tla.enumSet(tla.bool(false), tla.bool(true)) - - // good cases - assert(FinSetT(RecordT(SortedMap("a" -> IntT(), "b" -> BoolT()))) == - typeFinder.compute(tla.recSet(tla.str("a"), set12, tla.str("b"), boolSet), ConstT(), FinSetT(IntT()), - ConstT(), FinSetT(BoolT()))) - // bad cases - typeFinder.compute(tla.recSet(tla.str("a"), int1, tla.str("b"), boolSet), ConstT(), IntT(), ConstT(), - FinSetT(BoolT())) - assert(typeFinder.typeErrors.nonEmpty) - } - - test("compute [a |-> 1, b |-> FALSE]") { - val typeFinder = new TrivialTypeFinder() - val rec = tla.enumFun(tla.str("a"), tla.int(1), tla.str("b"), tla.bool(false)) - val cellType = typeFinder.compute(rec, ConstT(), IntT(), ConstT(), BoolT()) - assert(RecordT(SortedMap("a" -> IntT(), "b" -> BoolT())) == cellType) - typeFinder.compute(rec, IntT(), IntT(), ConstT(), BoolT()) - assert(typeFinder.typeErrors.nonEmpty) - } - - test("compute <<1, FALSE>>") { - val typeFinder = new TrivialTypeFinder() - val tup = tla.tuple(tla.int(1), tla.bool(false)) - val cellType = typeFinder.compute(tup, IntT(), BoolT()) - assert(TupleT(Seq(IntT(), BoolT())) == cellType) - // an empty tuple is a sequence, but its type is unknown - assert(SeqT(UnknownT()) == typeFinder.compute(tla.tuple())) - } - - test("compute <<1, FALSE>>[1]") { - val typeFinder = new TrivialTypeFinder() - val tup = tla.tuple(tla.int(1), tla.bool(false)) - val tupleType = TupleT(Seq(IntT(), BoolT())) - assert( - IntT() == - typeFinder.compute(tla.appFun(tup, tla.int(1)), tupleType, IntT())) - assert( - BoolT() == - typeFinder.compute(tla.appFun(tup, tla.int(2)), tupleType, IntT())) - // out-of-range expressions - typeFinder.compute(tla.appFun(tup, tla.int(0)), tupleType, IntT()) - assert(typeFinder.typeErrors.nonEmpty) - typeFinder.compute(tla.appFun(tup, tla.int(3)), tupleType, IntT()) - assert(typeFinder.typeErrors.nonEmpty) - // only integer constants are allowed! - typeFinder.compute(tla.appFun(tup, tla.name("j")), tupleType, IntT()) - assert(typeFinder.typeErrors.nonEmpty) - } - - test("compute r.a") { - val typeFinder = new TrivialTypeFinder() - val rec = tla.enumFun(tla.str("a"), tla.int(1), tla.str("b"), tla.bool(false)) - val recType = RecordT(SortedMap("a" -> IntT(), "b" -> BoolT())) - - assert( - IntT() == - typeFinder.compute(tla.appFun(rec, tla.str("a")), recType, ConstT())) - assert( - BoolT() == - typeFinder.compute(tla.appFun(rec, tla.str("b")), recType, ConstT())) - // out-of-range expressions - typeFinder.compute(tla.appFun(rec, tla.str("c")), recType, ConstT()) - assert(typeFinder.typeErrors.nonEmpty) - // only string constants are allowed! - typeFinder.compute(tla.appFun(rec, tla.name("x")), recType, ConstT()) - assert(typeFinder.typeErrors.nonEmpty) - } - - test("compute f[x]") { - val typeFinder = new TrivialTypeFinder() - val fun = tla.name("f") // we know just the name - val funType = FunT(FinSetT(IntT()), BoolT()) - - assert( - BoolT() == - typeFinder.compute(tla.appFun(fun, tla.int(1)), funType, IntT())) - // wrong argument type - typeFinder.compute(tla.appFun(fun, tla.str("c")), funType, ConstT()) - assert(typeFinder.typeErrors.nonEmpty) - } - - test("compute DOMAIN fun") { - val typeFinder = new TrivialTypeFinder() - val fun = tla.name("f") // we know just the name - val funType = FunT(FinSetT(BoolT()), BoolT()) - - assert(FinSetT(BoolT()) == typeFinder.compute(tla.dom(fun), funType)) - // wrong argument type - typeFinder.compute(tla.dom(fun), BoolT()) - assert(typeFinder.typeErrors.nonEmpty) - } - - test("compute DOMAIN rec") { - val typeFinder = new TrivialTypeFinder() - val rec = tla.name("r") // we know just the name - val recType = RecordT(SortedMap("a" -> IntT(), "b" -> BoolT())) - - assert(FinSetT(ConstT()) == typeFinder.compute(tla.dom(rec), recType)) - // wrong argument type - typeFinder.compute(tla.dom(rec), ConstT()) - assert(typeFinder.typeErrors.nonEmpty) - } - - test("compute DOMAIN tuple") { - val typeFinder = new TrivialTypeFinder() - val tup = tla.name("t") // we know just the name - val tupType = TupleT(Seq(IntT(), BoolT())) - - assert(FinSetT(IntT()) == typeFinder.compute(tla.dom(tup), tupType)) - // wrong argument type - typeFinder.compute(tla.dom(tup), IntT()) - assert(typeFinder.typeErrors.nonEmpty) - } - - test("compute [x \\in S |-> e]") { - val typeFinder = new TrivialTypeFinder() - val S = tla.name("S") - val x = tla.name("x") - val T = tla.name("T") - val y = tla.name("y") - val e = tla.name("e") - // good cases - assert(FunT(FinSetT(IntT()), BoolT()) == - typeFinder.compute(tla.funDef(e, x, S), BoolT(), IntT(), FinSetT(IntT()))) - assert(FunT(FinSetT(TupleT(Seq(IntT(), ConstT()))), BoolT()) == - typeFinder.compute(tla.funDef(e, x, S, y, T), BoolT(), IntT(), FinSetT(IntT()), ConstT(), FinSetT(ConstT()))) - // bad cases - assertThrows[TypeException] { - typeFinder.compute(tla.funDef(e, x, S), BoolT(), ConstT(), IntT()) - } - } - - test("compute [f EXCEPT ![e] = g]") { - val typeFinder = new TrivialTypeFinder() - val f = tla.name("f") - val e = tla.name("e") - val g = tla.name("g") - // good cases - val funT = FunT(FinSetT(IntT()), BoolT()) - // TlaFunOper.except expects an index (single-dimensional as well as multi-dimensional) wrapped into a tuple - assert( - funT == - typeFinder.compute(tla.except(f, e, g), funT, TupleT(Seq(IntT())), BoolT())) - val fun2T = FunT(FinSetT(TupleT(Seq(IntT(), ConstT()))), BoolT()) - assert( - fun2T == - typeFinder.compute(tla.except(f, e, g), fun2T, TupleT(Seq(TupleT(Seq(IntT(), ConstT())))), BoolT())) - // bad cases - typeFinder.compute(tla.except(f, e, g), funT, IntT(), BoolT()) - assert(typeFinder.typeErrors.nonEmpty) - // the type of the new value must agree with the function result - typeFinder.compute(tla.except(f, e, g), funT, TupleT(Seq(IntT())), IntT()) - assert(typeFinder.typeErrors.nonEmpty) - } - - test("compute FiniteSet operators") { - val typeFinder = new TrivialTypeFinder() - val S = tla.name("S") - // good cases - assert(BoolT() == typeFinder.compute(tla.isFin(S), FinSetT(IntT()))) - assert(IntT() == typeFinder.compute(tla.card(S), FinSetT(IntT()))) - // bad cases - typeFinder.compute(tla.isFin(S), IntT()) - assert(typeFinder.typeErrors.nonEmpty) - typeFinder.compute(tla.card(S), IntT()) - assert(typeFinder.typeErrors.nonEmpty) - } - - test("compute Sequences operators") { - val typeFinder = new TrivialTypeFinder() - val seq = tla.name("seq") - val x = tla.name("x") - // good cases - assert(IntT() == typeFinder.compute(tla.head(seq), SeqT(IntT()))) - assert(SeqT(IntT()) == typeFinder.compute(tla.tail(seq), SeqT(IntT()))) - assert(SeqT(IntT()) == typeFinder.compute(tla.append(seq, x), SeqT(IntT()), IntT())) - assert(SeqT(IntT()) == typeFinder.compute(tla.concat(seq, seq), SeqT(IntT()), SeqT(IntT()))) - assert(IntT() == typeFinder.compute(tla.len(seq), SeqT(IntT()))) - assert(SeqT(BoolT()) == typeFinder.compute(tla.subseq(seq, x, x), SeqT(BoolT()), IntT(), IntT())) - // bad cases - typeFinder.compute(tla.head(seq), IntT()) - assert(typeFinder.typeErrors.nonEmpty) - typeFinder.compute(tla.tail(seq), IntT()) - assert(typeFinder.typeErrors.nonEmpty) - typeFinder.compute(tla.append(seq, x), SeqT(IntT()), BoolT()) - assert(typeFinder.typeErrors.nonEmpty) - typeFinder.compute(tla.append(seq, x), IntT(), IntT()) - assert(typeFinder.typeErrors.nonEmpty) - typeFinder.compute(tla.append(seq, x), TupleT(Seq(IntT())), IntT()) - assert(typeFinder.typeErrors.nonEmpty) - typeFinder.compute(tla.concat(seq, seq), SeqT(IntT()), SeqT(BoolT())) - assert(typeFinder.typeErrors.nonEmpty) - typeFinder.compute(tla.len(seq), IntT()) - assert(typeFinder.typeErrors.nonEmpty) - typeFinder.compute(tla.subseq(seq, x, x), SeqT(BoolT()), BoolT(), IntT()) - assert(typeFinder.typeErrors.nonEmpty) - assertThrows[NotImplementedError] { - typeFinder.compute(tla.selectseq(seq, x), SeqT(IntT()), BoolT()) - } - } - - test("compute TLC operators") { - val typeFinder = new TrivialTypeFinder() - val e = tla.name("e") - val msg = "message" - // good cases: - // We only allow assert to return a Boolean result. - // When you use assert in a non-Boolean expression, provide the tool with a type annotation. - assert(BoolT() == typeFinder.compute(tla.tlcAssert(e, msg), BoolT(), ConstT())) - // bad cases - assert(BoolT() == typeFinder.compute(tla.tlcAssert(e, msg), IntT(), ConstT())) - assert(typeFinder.typeErrors.nonEmpty) - assert(BoolT() == typeFinder.compute(tla.tlcAssert(e, msg), BoolT(), IntT())) - assert(typeFinder.typeErrors.nonEmpty) - } - - test("compute labels") { - val typeFinder = new TrivialTypeFinder() - val e = tla.name("e") - val msg = "message" - // good cases: - // We only allow assert to return a Boolean result. - // When you use assert in a non-Boolean expression, provide the tool with a type annotation. - assert(IntT() == typeFinder.compute(tla.label(tla.int(1), "lab", "a"), IntT(), ConstT(), ConstT())) - assert(BoolT() == typeFinder.compute(tla.label(tla.bool(false), "lab", "a"), BoolT(), ConstT(), ConstT())) - // no bad cases, as it is impossible to construct an ill-typed label - } - - test("inferAndSave variable assignment") { - val typeFinder = new TrivialTypeFinder() - val x = tla.name("x") - val assign = tla.assignPrime(x, tla.int(1)) - assert(typeFinder.inferAndSave(assign).contains(BoolT())) - assert(IntT() == typeFinder.varTypes("x'")) - } - - test("inferAndSave double assignment") { - val typeFinder = new TrivialTypeFinder() - val x = tla.name("x") - val y = tla.name("y") - // double assignment is fine as soon as the types are preserved - val assign = - tla.or( - tla.assignPrime(x, tla.int(1)), - tla.assignPrime(x, tla.int(3)) - ) - assert(typeFinder.inferAndSave(assign).contains(BoolT())) - assert(IntT() == typeFinder.varTypes("x'")) - } - - test("inferAndSave set filter") { - val typeFinder = new TrivialTypeFinder() - val x = tla.name("x") - val filter = tla.filter(x, tla.enumSet(tla.int(3)), tla.bool(true)) - assert(typeFinder.inferAndSave(filter).contains(FinSetT(IntT()))) - assert(IntT() == typeFinder.varTypes("x")) - } - - test("inferAndSave set map") { - val typeFinder = new TrivialTypeFinder() - val x = tla.name("x") - val y = tla.name("y") - val map = tla.map(tla.enumSet(x), x, tla.enumSet(tla.int(3))) - assert(typeFinder.inferAndSave(map).contains(FinSetT(FinSetT(IntT())))) - assert(IntT() == typeFinder.varTypes("x")) - } - - test("inferAndSave exists/forall") { - var typeFinder = new TrivialTypeFinder() - val x = tla.name("x") - val exists = tla.exists(x, tla.enumSet(tla.int(3)), tla.bool(true)) - assert(typeFinder.inferAndSave(exists).contains(BoolT())) - assert(IntT() == typeFinder.varTypes("x")) - typeFinder = new TrivialTypeFinder() - val forall = tla.exists(x, tla.enumSet(tla.int(3)), tla.bool(true)) - assert(typeFinder.inferAndSave(forall).contains(BoolT())) - assert(IntT() == typeFinder.varTypes("x")) - // bad cases - typeFinder = new TrivialTypeFinder() - assert(typeFinder.inferAndSave(tla.exists(x, tla.enumSet(tla.int(3)), tla.int(4))).isEmpty) - assert(typeFinder.typeErrors.nonEmpty) - } - - test("inferAndSave CHOOSE") { - var typeFinder = new TrivialTypeFinder() - val x = tla.name("x") - val choose = tla.choose(x, tla.enumSet(tla.int(3)), tla.bool(true)) - assert(typeFinder.inferAndSave(choose).contains(IntT())) - assert(IntT() == typeFinder.varTypes("x")) - // bad cases - typeFinder = new TrivialTypeFinder() - val badChoose = tla.choose(x, tla.enumSet(tla.int(3)), tla.int(4)) - assert(typeFinder.inferAndSave(badChoose).isEmpty) - assert(typeFinder.typeErrors.nonEmpty) - } - - test("inferAndSave [x \\in S |-> e]") { - val typeFinder = new TrivialTypeFinder() - val x = tla.name("x") - val map = tla.funDef(tla.enumSet(x), x, tla.enumSet(tla.int(1))) - assert(typeFinder.inferAndSave(map).contains(FunT(FinSetT(IntT()), FinSetT(IntT())))) - assert(IntT() == typeFinder.varTypes("x")) - } - - test("inferAndSave [x \\in S, y \\in T |-> e]") { - val typeFinder = new TrivialTypeFinder() - val x = tla.name("x") - val y = tla.name("y") - val map = tla.funDef(tla.enumSet(x), x, tla.enumSet(tla.int(1)), y, tla.enumSet(tla.bool(false))) - assert(typeFinder.inferAndSave(map).contains(FunT(FinSetT(TupleT(Seq(IntT(), BoolT()))), FinSetT(IntT())))) - assert(IntT() == typeFinder.varTypes("x")) - assert(BoolT() == typeFinder.varTypes("y")) - } - - test("inferAndSave LET A == 1 + 2 IN 1 + A") { - val typeFinder = new TrivialTypeFinder() - val decl = TlaOperDecl("A", List(), tla.plus(tla.int(1), tla.int(2))) - val letIn = tla.letIn(tla.plus(tla.int(1), tla.appDecl(decl)), decl) - assert(typeFinder.inferAndSave(letIn).contains(IntT())) - assert(IntT() == typeFinder.varTypes("A")) - } - - test("inferAndSave recFunRef w/o annotation") { - val typeFinder = new TrivialTypeFinder() - val recFunRef = tla.recFunRef() - typeFinder.inferAndSave(recFunRef) - assert(typeFinder.typeErrors.nonEmpty) - } - - test("inferAndSave recFunRef with annotation") { - val typeFinder = new TrivialTypeFinder() - val recFunRef = tla.withType(tla.recFunRef(), tla.funSet(ValEx(TlaIntSet), ValEx(TlaIntSet))) - val tp = typeFinder.inferAndSave(recFunRef) - assert(typeFinder.typeErrors.isEmpty) - assert(tp.contains(FunT(FinSetT(IntT()), IntT()))) - } - - test("inferAndSave recFunDef with annotation") { - val typeFinder = new TrivialTypeFinder() - val recFunRef = tla.withType(tla.recFunRef(), tla.funSet(ValEx(TlaIntSet), ValEx(TlaIntSet))) - val recFunApply = tla.appFun(recFunRef, NameEx("x")) - // f[x \in {1, 2}] == f[x] - val recFun = tla.recFunDef(recFunApply, NameEx("x"), tla.enumSet(tla.int(1), tla.int(2))) - val tp = typeFinder.inferAndSave(recFun) - assert(typeFinder.typeErrors.isEmpty) - assert(tp.contains(FunT(FinSetT(IntT()), IntT()))) - } - - test("inferAndSave type annotation") { - val typeFinder = new TrivialTypeFinder() - val emptySet = tla.enumSet().untyped() - - val annotatedEx = tla.withType(emptySet, tla.enumSet(ValEx(TlaIntSet))).untyped() - assert(typeFinder.inferAndSave(annotatedEx).contains(FinSetT(IntT()))) - // check that the type of the underlying expression has been changed as well - assert(FinSetT(IntT()) == typeFinder.compute(emptySet)) - } - - // regression, see issue #292 - test("error on type annotation inside type annotation") { - val typeFinder = new TrivialTypeFinder() - val ex = tla.enumSet() - val annotatedEx = tla.withType(ex, tla.withType(tla.enumSet(), tla.enumSet(ValEx(TlaIntSet)))) - assertThrows[TypeException] { typeFinder.inferAndSave(annotatedEx) } - } - - // Since the introduction of BmcOper.assign, the old assignments need to be transformed - // into the form \E t \in S: x' = t - test("inferAndSave from the wild") { - import IncrementalRenaming.makeName - val init = tla.declOp(makeName("RenamedInit", 0), - tla.and( - tla.assignPrime( - tla.name("recOne"), - tla.withType( - tla.enumFun( - tla.str("x"), - tla.str("y") - ), - tla.enumFun( - tla.str("x"), - ValEx(TlaStrSet), - tla.str("y"), - ValEx(TlaIntSet) - ) - ) - ), - tla.assignPrime( - tla.name("recTwo"), - tla.withType( - tla.enumFun( - tla.str("x"), - tla.str("x") - ), - tla.enumFun( - tla.str("x"), - ValEx(TlaStrSet), - tla.str("z"), - tla.enumSet(ValEx(TlaIntSet)) - ) - ) - ) - )) - - val next1 = tla.declOp(makeName("RenamedNext", 0), - tla.and( - tla.assignPrime( - tla.name("recOne"), - tla.withType( - tla.enumFun( - tla.str("x"), - tla.str("x") - ), - tla.enumFun( - tla.str("x"), - ValEx(TlaStrSet), - tla.str("y"), - ValEx(TlaIntSet) - ) - ) - ), - tla.assignPrime( - tla.name("recTwo"), - tla.withType( - tla.enumFun( - tla.str("x"), - tla.str("x") - ), - tla.enumFun( - tla.str("x"), - ValEx(TlaStrSet), - tla.str("z"), - tla.enumSet(ValEx(TlaIntSet)) - ) - ) - ) - )) - - val next2 = tla.declOp(makeName("RenamedNext", 1), - tla.and( - tla.assignPrime( - tla.name("recOne"), - tla.withType( - tla.enumFun( - tla.str("x"), - tla.str("x") - ), - tla.enumFun( - tla.str("x"), - ValEx(TlaStrSet), - tla.str("y"), - ValEx(TlaIntSet) - ) - ) - ), - tla.assignPrime( - tla.name("recTwo"), - tla.withType( - tla.enumFun( - tla.str("x"), - tla.str("z") - ), - tla.enumFun( - tla.str("x"), - ValEx(TlaStrSet), - tla.str("z"), - tla.enumSet(ValEx(TlaIntSet)) - ) - ) - ) - )) - - val decls = Seq( - init, - next1, - next2 - ) - - val ttf = new TrivialTypeFinder - - decls foreach { d => - ttf.inferAndSave(d.body) - } - - assert(ttf.typeErrors.isEmpty) - } -} diff --git a/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/analyses/TestExpansionMarker.scala b/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/analyses/TestExpansionMarker.scala index 6d57e66f3d..b0a63de04c 100644 --- a/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/analyses/TestExpansionMarker.scala +++ b/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/analyses/TestExpansionMarker.scala @@ -1,9 +1,9 @@ package at.forsyte.apalache.tla.bmcmt.analyses +import at.forsyte.apalache.tla.lir.{BoolT1, FunT1, IntT1, SetT1} import at.forsyte.apalache.tla.lir.convenience.tla import at.forsyte.apalache.tla.lir.transformations.impl.TrackerWithListeners -import at.forsyte.apalache.tla.typecheck.TypedPredefs._ -import at.forsyte.apalache.tla.typecheck.{BoolT1, FunT1, IntT1, SetT1} +import at.forsyte.apalache.tla.lir.TypedPredefs._ import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner import org.scalatest.{BeforeAndAfterEach, FunSuite} diff --git a/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/trex/TestFilteredTransitionExecutor.scala b/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/trex/TestFilteredTransitionExecutor.scala index 9de0e80a0a..7cf23cddd7 100644 --- a/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/trex/TestFilteredTransitionExecutor.scala +++ b/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/trex/TestFilteredTransitionExecutor.scala @@ -3,7 +3,6 @@ package at.forsyte.apalache.tla.bmcmt.trex import at.forsyte.apalache.tla.bmcmt.SymbStateRewriterImpl import at.forsyte.apalache.tla.bmcmt.analyses.ExprGradeStoreImpl import at.forsyte.apalache.tla.bmcmt.smt.{SolverConfig, Z3SolverContext} -import at.forsyte.apalache.tla.bmcmt.types.eager.TrivialTypeFinder import at.forsyte.apalache.tla.lir._ import at.forsyte.apalache.tla.lir.convenience.tla import at.forsyte.apalache.tla.lir.UntypedPredefs._ @@ -23,9 +22,8 @@ class TestFilteredTransitionExecutor extends fixture.FunSuite { type FixtureParam = ExecutorContextT override protected def withFixture(test: OneArgTest): Outcome = { - val typeFinder = new TrivialTypeFinder() val solver = new Z3SolverContext(SolverConfig(debug = false, profile = false, randomSeed = 0)) - val rewriter = new SymbStateRewriterImpl(solver, typeFinder, new ExprGradeStoreImpl()) + val rewriter = new SymbStateRewriterImpl(solver, new ExprGradeStoreImpl()) val exeCtx = new IncrementalExecutionContext(rewriter) try { test(exeCtx) diff --git a/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/trex/TestTransitionExecutorImplWithIncremental.scala b/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/trex/TestTransitionExecutorImplWithIncremental.scala index 23229e62b8..f533009322 100644 --- a/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/trex/TestTransitionExecutorImplWithIncremental.scala +++ b/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/trex/TestTransitionExecutorImplWithIncremental.scala @@ -3,12 +3,9 @@ package at.forsyte.apalache.tla.bmcmt.trex import at.forsyte.apalache.tla.bmcmt.SymbStateRewriterImpl import at.forsyte.apalache.tla.bmcmt.analyses._ import at.forsyte.apalache.tla.bmcmt.smt.{SolverConfig, Z3SolverContext} -import at.forsyte.apalache.tla.bmcmt.types.eager.TrivialTypeFinder -import at.forsyte.apalache.tla.lir._ -import at.forsyte.apalache.tla.lir.convenience.tla import org.junit.runner.RunWith +import org.scalatest.Outcome import org.scalatest.junit.JUnitRunner -import org.scalatest.{Outcome, fixture} /** * The tests for TransitionExecutorImpl that are using IncrementalSnapshot. @@ -19,9 +16,8 @@ import org.scalatest.{Outcome, fixture} class TestTransitionExecutorImplWithIncremental extends AbstractTestTransitionExecutorImpl[IncrementalExecutionContextSnapshot] { override protected def withFixture(test: OneArgTest): Outcome = { - val typeFinder = new TrivialTypeFinder() val solver = new Z3SolverContext(SolverConfig(debug = false, profile = false, randomSeed = 0)) - val rewriter = new SymbStateRewriterImpl(solver, typeFinder, new ExprGradeStoreImpl()) + val rewriter = new SymbStateRewriterImpl(solver, new ExprGradeStoreImpl()) val exeCtx = new IncrementalExecutionContext(rewriter) try { test(exeCtx) diff --git a/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/trex/TestTransitionExecutorImplWithOffline.scala b/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/trex/TestTransitionExecutorImplWithOffline.scala index 90f7958b12..60abbaf4e5 100644 --- a/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/trex/TestTransitionExecutorImplWithOffline.scala +++ b/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/trex/TestTransitionExecutorImplWithOffline.scala @@ -3,7 +3,6 @@ package at.forsyte.apalache.tla.bmcmt.trex import at.forsyte.apalache.tla.bmcmt.SymbStateRewriterImpl import at.forsyte.apalache.tla.bmcmt.analyses._ import at.forsyte.apalache.tla.bmcmt.smt.{RecordingSolverContext, SolverConfig} -import at.forsyte.apalache.tla.bmcmt.types.eager.TrivialTypeFinder import org.junit.runner.RunWith import org.scalatest.Outcome import org.scalatest.junit.JUnitRunner @@ -17,9 +16,8 @@ import org.scalatest.junit.JUnitRunner class TestTransitionExecutorImplWithOffline extends AbstractTestTransitionExecutorImpl[OfflineExecutionContextSnapshot] { override protected def withFixture(test: OneArgTest): Outcome = { - val typeFinder = new TrivialTypeFinder() val solver = RecordingSolverContext.createZ3(None, SolverConfig(debug = false, profile = false, randomSeed = 0)) - val rewriter = new SymbStateRewriterImpl(solver, typeFinder, new ExprGradeStoreImpl()) + val rewriter = new SymbStateRewriterImpl(solver, new ExprGradeStoreImpl()) val exeCtx = new OfflineExecutionContext(rewriter) try { test(exeCtx) diff --git a/tla-import/src/main/scala/at/forsyte/apalache/io/tlc/config/TlcConfig.scala b/tla-import/src/main/scala/at/forsyte/apalache/io/tlc/config/TlcConfig.scala index faaf8e5301..52bcae1caa 100644 --- a/tla-import/src/main/scala/at/forsyte/apalache/io/tlc/config/TlcConfig.scala +++ b/tla-import/src/main/scala/at/forsyte/apalache/io/tlc/config/TlcConfig.scala @@ -1,10 +1,9 @@ package at.forsyte.apalache.io.tlc.config import at.forsyte.apalache.io.tlc.config.ConfigModelValue.STR_PREFIX -import at.forsyte.apalache.tla.lir.oper.TlaSetOper -import at.forsyte.apalache.tla.lir.{OperEx, TlaEx, ValEx} -import at.forsyte.apalache.tla.lir.values.{TlaBool, TlaInt, TlaStr} -import at.forsyte.apalache.tla.lir.UntypedPredefs._ +import at.forsyte.apalache.tla.lir.convenience.tla +import at.forsyte.apalache.tla.lir.{BoolT1, IntT1, SetT1, StrT1, TlaEx, Typed, VarT1} +import at.forsyte.apalache.tla.lir.TypedPredefs._ import scala.util.parsing.input.NoPosition @@ -60,7 +59,12 @@ object ConfigModelValue { * @param name the name of a model value */ case class ConfigModelValue(name: String) extends ConfigConstExpr { - override def toTlaEx: TlaEx = ValEx(TlaStr(STR_PREFIX + name)) + override def toTlaEx: TlaEx = { + // currently, we use the type Str for all model values. + // In the future, we might want to distinguish between different uninterpreted types. + // See https://github.com/informalsystems/apalache/issues/570 + tla.str(STR_PREFIX + name).typed(StrT1()) + } } /** @@ -69,7 +73,7 @@ case class ConfigModelValue(name: String) extends ConfigConstExpr { * @param num an integer as BigInt */ case class ConfigIntValue(num: BigInt) extends ConfigConstExpr { - override def toTlaEx: TlaEx = ValEx(TlaInt(num)) + override def toTlaEx: TlaEx = tla.bigInt(num).typed(IntT1()) } /** @@ -78,7 +82,7 @@ case class ConfigIntValue(num: BigInt) extends ConfigConstExpr { * @param b a boolean */ case class ConfigBoolValue(b: Boolean) extends ConfigConstExpr { - override def toTlaEx: TlaEx = ValEx(TlaBool(b)) + override def toTlaEx: TlaEx = tla.bool(b).typed(BoolT1()) } /** @@ -87,7 +91,7 @@ case class ConfigBoolValue(b: Boolean) extends ConfigConstExpr { * @param str a string */ case class ConfigStrValue(str: String) extends ConfigConstExpr { - override def toTlaEx: TlaEx = ValEx(TlaStr(str)) + override def toTlaEx: TlaEx = tla.str(str).typed(StrT1()) } /** @@ -96,7 +100,21 @@ case class ConfigStrValue(str: String) extends ConfigConstExpr { * @param elems the set elements, which are constant expression themselves. */ case class ConfigSetValue(elems: ConfigConstExpr*) extends ConfigConstExpr { - override def toTlaEx: TlaEx = OperEx(TlaSetOper.enumSet, elems.map(_.toTlaEx): _*) + override def toTlaEx: TlaEx = { + val setElems = elems.map(_.toTlaEx) + if (setElems.isEmpty) { + // the element type is uknown, introduce a polymorphic type Set(a) + tla.enumSet().typed(SetT1(VarT1(0))) + } else { + // the element type should be unique + val headType = setElems.head.typeTag.asTlaType1() + if (setElems.tail.exists(_.typeTag != Typed(headType))) { + throw new TlcConfigParseError("Set elements have different types: " + setElems.mkString(", "), NoPosition) + } else { + tla.enumSet(setElems: _*).typed(SetT1(headType)) + } + } + } } /** diff --git a/tla-pp/src/main/scala/at/forsyte/apalache/tla/pp/ConstAndDefRewriter.scala b/tla-pp/src/main/scala/at/forsyte/apalache/tla/pp/ConstAndDefRewriter.scala index 4e5af86f6d..df706a409a 100644 --- a/tla-pp/src/main/scala/at/forsyte/apalache/tla/pp/ConstAndDefRewriter.scala +++ b/tla-pp/src/main/scala/at/forsyte/apalache/tla/pp/ConstAndDefRewriter.scala @@ -1,10 +1,10 @@ package at.forsyte.apalache.tla.pp -import at.forsyte.apalache.tla.lir._ +import at.forsyte.apalache.tla.lir.{OperT1, _} import at.forsyte.apalache.tla.lir.oper._ import at.forsyte.apalache.tla.lir.transformations.standard.{DeclarationSorter, ModuleByExTransformer, ReplaceFixed} import at.forsyte.apalache.tla.lir.transformations.{TlaModuleTransformation, TransformationTracker} -import at.forsyte.apalache.tla.lir.UntypedPredefs._ +import TypedPredefs._ import com.typesafe.scalalogging.LazyLogging /** @@ -19,7 +19,7 @@ class ConstAndDefRewriter(tracker: TransformationTracker) extends TlaModuleTrans val overrides = findOverrides(mod.operDeclarations) def transformDef: TlaDecl => TlaDecl = { - case TlaConstDecl(name) if overrides.contains(name) => + case d @ TlaConstDecl(name) if overrides.contains(name) => val overridingDef = overrides(name) if (overridingDef.formalParams.nonEmpty) { val nargs = overridingDef.formalParams.size @@ -27,10 +27,16 @@ class ConstAndDefRewriter(tracker: TransformationTracker) extends TlaModuleTrans logger.error(msg) logger.error(" > If you need support for n-ary CONSTANTS, write a feature request.") throw new OverridingError(msg, overridingDef.body) + } else if (d.typeTag != overridingDef.body.typeTag) { + val msg = s"The types of ${d.name} and ${overridingDef.name} do not match." + logger.error(msg) + throw new OverridingError(msg, overridingDef.body) } else { logger.info(s" > Replaced CONSTANT $name with ${overridingDef.body}") // Safe constructor: cannot be recursive - TlaOperDecl(name, List(), overridingDef.body) + // Instead of a constant, we have an operator now. Use an operator type. + val typeTag = Typed(OperT1(Seq(), d.typeTag.asTlaType1())) + TlaOperDecl(name, List(), overridingDef.body)(typeTag) } case TlaOperDecl(name, dfParams, _) if overrides.contains(name) => @@ -69,13 +75,16 @@ class ConstAndDefRewriter(tracker: TransformationTracker) extends TlaModuleTrans // Importantly, for every constant c, replace NameEx(c) with OperEx(TlaOper.apply, replacement). // This is needed as we distinguish the operator calls from constant and variable use. - def replaceConstWithCall(mod: TlaModule, name: String): TlaModule = { - val xform = ReplaceFixed(tracker)(NameEx(name), OperEx(TlaOper.apply, NameEx(name))) + def replaceConstWithCall(mod: TlaModule, constDecl: TlaConstDecl): TlaModule = { + val tag = constDecl.typeTag + val operTag = Typed(OperT1(Seq(), constDecl.typeTag.asTlaType1())) + val name = constDecl.name + val xform = ReplaceFixed(tracker)(NameEx(name)(tag), OperEx(TlaOper.apply, NameEx(name)(operTag))(tag)) val moduleXform = ModuleByExTransformer(xform) moduleXform(mod) } - val replacedConsts = mod.declarations.collect { case TlaConstDecl(name) if overrides.contains(name) => name } + val replacedConsts = mod.declarations.collect { case d @ TlaConstDecl(name) if overrides.contains(name) => d } val replaced = replacedConsts.foldLeft(sortedModule)(replaceConstWithCall) replaced } diff --git a/tla-pp/src/main/scala/at/forsyte/apalache/tla/pp/ConstSimplifierBase.scala b/tla-pp/src/main/scala/at/forsyte/apalache/tla/pp/ConstSimplifierBase.scala index 7598ac0eec..22cb37eec9 100644 --- a/tla-pp/src/main/scala/at/forsyte/apalache/tla/pp/ConstSimplifierBase.scala +++ b/tla-pp/src/main/scala/at/forsyte/apalache/tla/pp/ConstSimplifierBase.scala @@ -3,7 +3,6 @@ package at.forsyte.apalache.tla.pp import at.forsyte.apalache.tla.lir._ import at.forsyte.apalache.tla.lir.oper._ import at.forsyte.apalache.tla.lir.values.{TlaBool, TlaInt, TlaStr} -import at.forsyte.apalache.tla.typecheck.{BoolT1, IntT1} import scala.math.BigInt diff --git a/tla-pp/src/main/scala/at/forsyte/apalache/tla/pp/Desugarer.scala b/tla-pp/src/main/scala/at/forsyte/apalache/tla/pp/Desugarer.scala index a71568d2d3..a568c12d10 100644 --- a/tla-pp/src/main/scala/at/forsyte/apalache/tla/pp/Desugarer.scala +++ b/tla-pp/src/main/scala/at/forsyte/apalache/tla/pp/Desugarer.scala @@ -4,7 +4,8 @@ import at.forsyte.apalache.tla.lir._ import at.forsyte.apalache.tla.lir.convenience._ import at.forsyte.apalache.tla.lir.oper._ import at.forsyte.apalache.tla.lir.transformations.{TlaExTransformation, TransformationTracker} -import at.forsyte.apalache.tla.lir.UntypedPredefs._ +import at.forsyte.apalache.tla.lir.values.{TlaInt, TlaStr} +import TypedPredefs._ import javax.inject.Singleton @@ -14,7 +15,7 @@ import javax.inject.Singleton * @author Igor Konnov */ @Singleton -class Desugarer(tracker: TransformationTracker) extends TlaExTransformation { +class Desugarer(gen: UniqueNameGenerator, tracker: TransformationTracker) extends TlaExTransformation { override def apply(expr: TlaEx): TlaEx = { transform(expr) @@ -28,53 +29,64 @@ class Desugarer(tracker: TransformationTracker) extends TlaExTransformation { case ex @ OperEx(TlaFunOper.except, fun, args @ _*) => val trArgs = args map transform val (accessors, newValues) = TlaOper.deinterleave(trArgs) - val nonSingletons = accessors.collect { case OperEx(TlaFunOper.tuple, lst @ _*) => lst.size > 1 } - if (nonSingletons.isEmpty) { - // only singleton tuples, construct the same EXCEPT, but with transformed fun and args + val isMultidimensional = accessors.exists { case OperEx(TlaFunOper.tuple, lst @ _*) => lst.size > 1 } + if (accessors.length < 2 && !isMultidimensional) { + // the simplest update [ f EXCEPT ![i] = e ] OperEx(TlaFunOper.except, transform(fun) +: trArgs: _*)(ex.typeTag) } else { - // multiple accesses, e.g., ![i][j] = ... + // we have one of the following (or both): + // 1. a multi-dimension index: [f ![i_1]...[i_m] = e] + // 2. multiple indices: [f !a_1 = e_1, ..., !a_n = e_n] expandExcept(transform(fun), accessors, newValues) } case OperEx(TlaActionOper.unchanged, args @ _*) => // flatten all tuples, e.g., convert <> >> to [x, y, z] // construct a tuple for flattenTuplesInUnchanged, the type is bogus, as the tuple will be unpacked - val asTuple = tla.tuple(args.map(transform(_)): _*).untyped() + val transformedArgs = args.map(transform(_)) + val asTuple = tla + .tuple(transformedArgs: _*) + .typed(TupT1(transformedArgs.map(_.typeTag.asTlaType1()): _*)) val flatArgs = flattenTuplesInUnchanged(asTuple) // map every x to x' = x - val eqs = flatArgs map { x => tla.eql(tla.prime(x), x) } + val eqs = flatArgs map { x: TlaEx => + val tt = x.typeTag.asTlaType1() + val xb = tla.fromTlaEx(x) + tla + .eql(tla.prime(xb) ? "x", xb) + .typed(Map("b" -> BoolT1(), "x" -> tt), "b") + } // x' = x /\ y' = y /\ z' = z eqs match { case Seq() => // results from UNCHANGED <<>>, UNCHANGED << << >> >>, etc. - tla.bool(true).untyped() + tla.bool(true).typed() case Seq(one) => - one.untyped() + one case _ => - tla.and(eqs: _*).untyped() + tla.and(eqs: _*).typed(BoolT1()) } case OperEx(TlaOper.eq, OperEx(TlaFunOper.tuple, largs @ _*), OperEx(TlaFunOper.tuple, rargs @ _*)) => // <> = <> // produce pairwise comparison if (largs.length != rargs.length) { - tla.bool(false).untyped() + tla.bool(false).typed() } else { - val eqs = largs.zip(rargs) map { case (l, r) => tla.eql(this(l), this(r)) } - tla.and(eqs: _*).untyped() + val eqs = largs.zip(rargs) map { case (l, r) => tla.eql(this(l), this(r)).typed(BoolT1()) } + tla.and(eqs: _*).typed(BoolT1()) } case OperEx(TlaOper.ne, OperEx(TlaFunOper.tuple, largs @ _*), OperEx(TlaFunOper.tuple, rargs @ _*)) => // <> /= <> // produce pairwise comparison if (largs.length != rargs.length) { - tla.bool(true).untyped() + tla.bool(true).typed() } else { - val neqs = largs.zip(rargs) map { case (l, r) => tla.neql(this(l), this(r)) } - tla.or(neqs: _*).untyped() + val neqs = largs.zip(rargs) map { case (l, r) => tla.neql(this(l), this(r)).typed(BoolT1()) } + tla.or(neqs: _*).typed(BoolT1()) } case ex @ OperEx(TlaSetOper.filter, boundEx, setEx, predEx) => @@ -109,9 +121,9 @@ class Desugarer(tracker: TransformationTracker) extends TlaExTransformation { if (vars.length > 1) { // a function of multiple arguments: Introduce a tuple to contain all these arguments // we will need types in the future, commented out for now - // val pointType = TupT1(vars.map(_.typeTag.asTlaType1()): _*) - val point = tla.tuple(vars: _*).untyped() // future: typed(pointType) - val plane = tla.times(sets: _*).untyped() // future: typed(SetT1(pointType)) + val pointType = TupT1(vars.map(_.typeTag.asTlaType1()): _*) + val point = tla.tuple(vars: _*).typed(pointType) + val plane = tla.times(sets: _*).typed(SetT1(pointType)) // track the modification to point to the first variable and set tracker.hold(vars.head, point) tracker.hold(sets.head, plane) @@ -146,39 +158,73 @@ class Desugarer(tracker: TransformationTracker) extends TlaExTransformation { Seq(ex) } - private def expandExcept(topFun: TlaEx, accessors: Seq[TlaEx], newValues: Seq[TlaEx]): TlaEx = { - // The general case of [f EXCEPT !a_1 = e_1, ..., !a_n = e_n] - // See p. 304 in Specifying Systems. The definition of EXCEPT for multiple accessors is doubly inductive. - // The implementation below does not match the inductive definition from the book when n > 1. - // The fix is planned in issue: https://github.com/informalsystems/apalache/issues/647 - def untuple: PartialFunction[TlaEx, Seq[TlaEx]] = { case OperEx(TlaFunOper.tuple, args @ _*) => - args - } - - def unfoldKey(indicesInPrefix: Seq[TlaEx], indicesInSuffix: Seq[TlaEx], newValue: TlaEx): TlaEx = { - // produce [f[i_1]...[i_m] EXCEPT ![i_m+1] = unfoldKey(...) ] - indicesInSuffix match { - case Nil => newValue // nothing to unfold, just return g - case oneMoreIndex +: otherIndices => - // f[i_1]...[i_m] - val funApp = indicesInPrefix.foldLeft(topFun)((f, i) => tla.appFun(f, i)) - // the recursive call defines another chain of EXCEPTS - val rhs = unfoldKey(indicesInPrefix :+ oneMoreIndex, otherIndices, newValue) - OperEx(TlaFunOper.except, funApp, tla.tuple(oneMoreIndex), rhs) + private def expandExceptOne(topFun: TlaEx, accessor: TlaEx, newValue: TlaEx): Seq[TlaOperDecl] = { + // rewrite [ f EXCEPT ![i_1]...[i_n] = e ] + def rewrite(fun: TlaEx, keys: List[TlaEx]): Seq[TlaOperDecl] = { + keys match { + case Nil => + throw new LirError("Expected at least one key in EXCEPT, found none") + + case hd :: Nil => + val uniqueName = gen.newName() + // LET tmp == [ fun EXCEPT ![i_n] = e ] IN + val funT = fun.typeTag.asTlaType1() // either FunT1, RecT1, TupT1, or SeqT1 + val operT = OperT1(Seq(), funT) + val decl = tla + .declOp(uniqueName, tla.except(fun, tla.tuple(hd).typed(hd.typeTag.asTlaType1()), newValue).typed(funT)) + .typedOperDecl(operT) + Seq(decl) + + case hd :: tl => + // fun[a_i] + val funT = fun.typeTag.asTlaType1() // either FunT1, RecT1, TupT1, or SeqT1 + val operT = OperT1(Seq(), funT) + val (_, resT) = eatFunType(funT, hd) + val nested = tla.appFun(fun, hd).typed(resT) + // produce the expression for: [ fun[a_i] EXCEPT ![a_{i+1}]...[a_n] = e ] + val defs = rewrite(nested, tl) + // LET tmp == [ fun EXCEPT ![a_i] = output ] IN + // tmp() + val uniqueName = gen.newName() + val nestedFun = tla + .appOp(tla.name(defs.last.name) ? "F") + .typed(Map("F" -> operT, "r" -> resT), "r") + val outDef = tla + .declOp(uniqueName, tla.except(fun, tla.tuple(hd).typed(hd.typeTag.asTlaType1()), nestedFun).typed(funT)) + .typedOperDecl(operT) + defs :+ outDef } } - def eachPair(accessor: TlaEx, newValue: TlaEx): (TlaEx, TlaEx) = { - val indices = untuple(accessor) - // ![e_1][e_2]...[e_k] = g becomes ![e_1] = h - val lhs = tla.tuple(indices.head) - // h is computed by unfoldKey - val rhs = unfoldKey(Seq(indices.head), indices.tail, newValue) - (lhs, rhs) + accessor match { + case OperEx(TlaFunOper.tuple, keys @ _*) => + rewrite(topFun, keys.toList) + + case _ => + throw new LirError("Expected a tuple of keys as an accessor in EXCEPT. Found: " + accessor) } - val expandedPairs = accessors.zip(newValues).map((eachPair _).tupled) - val expandedArgs = (TlaOper.interleave _).tupled(expandedPairs.unzip) - OperEx(TlaFunOper.except, topFun +: expandedArgs: _*) + } + + private def expandExcept(fun: TlaEx, accessors: Seq[TlaEx], newValues: Seq[TlaEx]): TlaEx = { + // The general case of [f EXCEPT !a_1 = e_1, ..., !a_n = e_n] + // See p. 304 in Specifying Systems. The definition of EXCEPT for multiple accessors is doubly inductive. + assert(accessors.length == newValues.length) + val uniqueName = gen.newName() + val funT = fun.typeTag.asTlaType1() + // LET tmp == fun IN + val firstDef = tla.declOp(uniqueName, fun).typedOperDecl(OperT1(Seq(), funT)) + + val defs = + accessors.zip(newValues).foldLeft(Seq(firstDef)) { case (defs, (a, e)) => + val last = defs.last + val latest = tla.appOp(NameEx(last.name)(last.typeTag)).typed(funT) + defs ++ expandExceptOne(latest, a, e) + } + + val operT = OperT1(Seq(), funT) + tla + .letIn(tla.appOp(tla.name(defs.last.name) ? "F") ? "f", defs: _*) + .typed(Map("F" -> operT, "f" -> funT), "f") } /** @@ -194,7 +240,8 @@ class Desugarer(tracker: TransformationTracker) extends TlaExTransformation { // variable substitutions for the variables inside the tuples val subs = collectSubstitutions(Map(), boundEx) val newPred = substituteNames(subs, predEx) - Seq(NameEx(boundName), setEx, newPred) + val xtype = boundEx.typeTag.asTlaType1() + Seq(tla.name(boundName).typed(xtype), setEx, newPred) } /** @@ -206,12 +253,59 @@ class Desugarer(tracker: TransformationTracker) extends TlaExTransformation { */ def collapseTuplesInMap(mapEx: TlaEx, args: Seq[TlaEx]): Seq[TlaEx] = { val (boundEs, setEs) = TlaOper.deinterleave(args) - val boundNames = boundEs map mkTupleName // rename tuples into a names, if needed + val boundNames = boundEs map { e => + val tupleName = mkTupleName(e) + NameEx(tupleName)(e.typeTag) + } // rename tuples into a names, if needed // variable substitutions for the variables inside the tuples val subs = boundEs.foldLeft(Map[String, TlaEx]())(collectSubstitutions) val newMapEx = substituteNames(subs, mapEx) // collect the arguments back - newMapEx +: TlaOper.interleave(boundNames.map(NameEx(_)), setEs) + newMapEx +: TlaOper.interleave(boundNames, setEs) + } + + // this looks like a useful utility function. Move it somewhere else? + private def eatFunType(funT: TlaType1, arg: TlaEx): (TlaType1, TlaType1) = { + (funT, arg) match { + case (FunT1(argT, resT), _) => + if (Typed(argT) != arg.typeTag) { + val actualArgType = arg.typeTag.asTlaType1() + throw new TypingException(s"Expected a function argument of type $argT, found $actualArgType") + } else { + (argT, resT) + } + + case (SeqT1(elem), _) => + if (Typed(IntT1()) != arg.typeTag) { + val actualArgType = arg.typeTag.asTlaType1() + throw new TypingException(s"Expected a sequence argument to be an integer, found $actualArgType") + } else { + (IntT1(), elem) + } + + case (tt @ RecT1(fieldTypes), ValEx(TlaStr(key))) => + if (fieldTypes.contains(key)) { + (StrT1(), fieldTypes(key)) + } else { + throw new IllegalArgumentException(s"No key $key in $tt") + } + + case (tt @ RecT1(_), _) => + throw new TypingException(s"Expected a string argument for $tt, found: $arg") + + case (tt @ TupT1(elems @ _*), ValEx(TlaInt(index))) => + if (index > 0 && index <= elems.length) { + (IntT1(), elems(index.toInt - 1)) + } else { + throw new IllegalArgumentException(s"No index $index in $tt") + } + + case (tt @ TupT1(_), _) => + throw new TypingException(s"Expected a string argument for $tt, found: $arg") + + case _ => + throw new TypingException(s"Unexpected type in function application: $arg") + } } private def collectSubstitutions(subs: Map[String, TlaEx], ex: TlaEx): Map[String, TlaEx] = { @@ -222,9 +316,11 @@ class Desugarer(tracker: TransformationTracker) extends TlaExTransformation { val tupleName = mkTupleName(ex) // introduce a name, e.g., x_y_z for <> >> val indices = assignIndicesInTuple(Map(), ex, Seq()) - def indexToTlaEx(index: Seq[Int]): TlaEx = { - index.foldLeft(tla.name(tupleName): TlaEx) { (e, i) => - tla.appFun(e, tla.int(i)) + def indexToTlaEx(indices: Seq[Int]): TlaEx = { + indices.foldLeft(NameEx(tupleName)(ex.typeTag): TlaEx) { case (e, i) => + val index = tla.int(i).typed() + val (_, resT) = eatFunType(e.typeTag.asTlaType1(), index) + tla.appFun(e, index).typed(resT) } } @@ -285,5 +381,7 @@ class Desugarer(tracker: TransformationTracker) extends TlaExTransformation { } object Desugarer { - def apply(tracker: TransformationTracker): Desugarer = new Desugarer(tracker) + def apply(gen: UniqueNameGenerator, tracker: TransformationTracker): Desugarer = { + new Desugarer(gen, tracker) + } } diff --git a/tla-pp/src/main/scala/at/forsyte/apalache/tla/pp/ExprOptimizer.scala b/tla-pp/src/main/scala/at/forsyte/apalache/tla/pp/ExprOptimizer.scala index 7b075d341e..1e4a6d61e5 100644 --- a/tla-pp/src/main/scala/at/forsyte/apalache/tla/pp/ExprOptimizer.scala +++ b/tla-pp/src/main/scala/at/forsyte/apalache/tla/pp/ExprOptimizer.scala @@ -6,8 +6,7 @@ import at.forsyte.apalache.tla.lir.convenience.tla import at.forsyte.apalache.tla.lir.transformations.standard.{DeepCopy, FlatLanguagePred, ReplaceFixed} import at.forsyte.apalache.tla.lir.transformations.{LanguageWatchdog, TlaExTransformation, TransformationTracker} import at.forsyte.apalache.tla.lir.values.{TlaInt, TlaStr} -import at.forsyte.apalache.tla.typecheck.{BoolT1, IntT1, OperT1, SetT1, TlaType1, TypingException} -import at.forsyte.apalache.tla.typecheck.TypedPredefs._ +import TypedPredefs._ import javax.inject.Singleton import scala.math.BigInt diff --git a/tla-pp/src/main/scala/at/forsyte/apalache/tla/pp/Keramelizer.scala b/tla-pp/src/main/scala/at/forsyte/apalache/tla/pp/Keramelizer.scala index 62b4a200cd..f54725a80a 100644 --- a/tla-pp/src/main/scala/at/forsyte/apalache/tla/pp/Keramelizer.scala +++ b/tla-pp/src/main/scala/at/forsyte/apalache/tla/pp/Keramelizer.scala @@ -5,8 +5,7 @@ import at.forsyte.apalache.tla.lir.convenience._ import at.forsyte.apalache.tla.lir.oper._ import at.forsyte.apalache.tla.lir.transformations.standard.FlatLanguagePred import at.forsyte.apalache.tla.lir.transformations.{LanguageWatchdog, TlaExTransformation, TransformationTracker} -import at.forsyte.apalache.tla.typecheck.TypedPredefs._ -import at.forsyte.apalache.tla.typecheck.{BoolT1, SetT1, TlaType1, TupT1} +import TypedPredefs._ import javax.inject.Singleton @@ -155,7 +154,7 @@ class Keramelizer(gen: UniqueNameGenerator, tracker: TransformationTracker) private def transformControl: PartialFunction[TlaEx, TlaEx] = { case expr @ OperEx(TlaControlOper.caseWithOther, otherEx, args @ _*) => def decorateWithIf(elseEx: TlaEx, guardAction: (TlaEx, TlaEx)): TlaEx = { - tla.ite(guardAction._1, guardAction._2, elseEx).typed(BoolT1()) + tla.ite(guardAction._1, guardAction._2, elseEx).typed(elseEx.typeTag.asTlaType1()) } // produce a chain of if-then-else expressions diff --git a/tla-pp/src/main/scala/at/forsyte/apalache/tla/pp/Normalizer.scala b/tla-pp/src/main/scala/at/forsyte/apalache/tla/pp/Normalizer.scala index aa49ae10b0..66b8ba648c 100644 --- a/tla-pp/src/main/scala/at/forsyte/apalache/tla/pp/Normalizer.scala +++ b/tla-pp/src/main/scala/at/forsyte/apalache/tla/pp/Normalizer.scala @@ -1,11 +1,12 @@ package at.forsyte.apalache.tla.pp +import at.forsyte.apalache.tla.lir.TypedPredefs._ +import at.forsyte.apalache.tla.lir.convenience.tla import at.forsyte.apalache.tla.lir.oper._ import at.forsyte.apalache.tla.lir.transformations.standard.{FlatLanguagePred, ReplaceFixed} import at.forsyte.apalache.tla.lir.transformations.{LanguageWatchdog, TlaExTransformation, TransformationTracker} import at.forsyte.apalache.tla.lir.values.TlaBool import at.forsyte.apalache.tla.lir._ -import at.forsyte.apalache.tla.typecheck.{BoolT1, OperT1} import javax.inject.Singleton @@ -26,7 +27,7 @@ class Normalizer(tracker: TransformationTracker) extends TlaExTransformation { private def nnf(neg: Boolean): TlaExTransformation = tracker.trackEx { case ex @ ValEx(TlaBool(b)) => - OperEx(TlaBoolOper.not, ValEx(TlaBool(b ^ neg))(ex.typeTag))(ex.typeTag) + ValEx(TlaBool(b ^ neg))(ex.typeTag) case vex @ ValEx(_) => vex // this may be called when processing a non-Boolean expression diff --git a/tla-pp/src/main/scala/at/forsyte/apalache/tla/pp/OperAppToLetInDef.scala b/tla-pp/src/main/scala/at/forsyte/apalache/tla/pp/OperAppToLetInDef.scala index fe324c4900..5b54b3fbc5 100644 --- a/tla-pp/src/main/scala/at/forsyte/apalache/tla/pp/OperAppToLetInDef.scala +++ b/tla-pp/src/main/scala/at/forsyte/apalache/tla/pp/OperAppToLetInDef.scala @@ -5,8 +5,7 @@ import at.forsyte.apalache.tla.lir._ import at.forsyte.apalache.tla.lir.convenience._ import at.forsyte.apalache.tla.lir.transformations.standard.IncrementalRenaming import at.forsyte.apalache.tla.lir.transformations.{TlaExTransformation, TlaModuleTransformation, TransformationTracker} -import at.forsyte.apalache.tla.typecheck.{OperT1, TlaType1, TypingException} -import at.forsyte.apalache.tla.typecheck.TypedPredefs._ +import TypedPredefs._ /** * Replaces instances of user-defined operator applications with a LET-IN wrapper. diff --git a/tla-pp/src/main/scala/at/forsyte/apalache/tla/pp/ParameterNormalizer.scala b/tla-pp/src/main/scala/at/forsyte/apalache/tla/pp/ParameterNormalizer.scala index 7846a04684..cf21d593ce 100644 --- a/tla-pp/src/main/scala/at/forsyte/apalache/tla/pp/ParameterNormalizer.scala +++ b/tla-pp/src/main/scala/at/forsyte/apalache/tla/pp/ParameterNormalizer.scala @@ -1,18 +1,19 @@ package at.forsyte.apalache.tla.pp -import at.forsyte.apalache.tla.lir.oper.TlaOper -import at.forsyte.apalache.tla.lir.transformations.standard.ReplaceFixed import at.forsyte.apalache.tla.lir._ +import at.forsyte.apalache.tla.lir.convenience.tla +import at.forsyte.apalache.tla.lir.transformations.standard.ReplaceFixed import at.forsyte.apalache.tla.lir.transformations.{TlaExTransformation, TlaModuleTransformation, TransformationTracker} -import at.forsyte.apalache.tla.lir.UntypedPredefs._ +import TypedPredefs._ /** * Transforms a declaration A(x,y(_,_)) == e * into parameter-normal form, i.e. * A(x,y(_,_)) == LET x_new == x - * y_new(p1, p2) == y(p1,p2) - * IN e[ x_new/x, y_new/y ] + * y_new(p1, p2) == y(p1,p2) + * IN e[ x_new/x, y_new/y ] * This allows us to limit the number of substitutions when inlining A. + * * @param decisionFn Normal-form transformation will only be applied to operator declarations (both top-level and LET-IN), * for which `decisionFn` evaluates to true. Default: returns true for all recursive operators. */ @@ -24,8 +25,9 @@ class ParameterNormalizer( def normalizeDeclaration(decl: TlaOperDecl): TlaOperDecl = { // Since the body may contain LET-IN defined operators, we first normalize all of them val normalizedBody = normalizeInternalLetIn(decl.body) + val paramTypes = extractParamTypes(decl) val newBody = - if (decisionFn(decl)) normalizeParametersInEx(decl.formalParams)(normalizedBody) + if (decisionFn(decl)) normalizeParametersInEx(decl.formalParams, paramTypes)(normalizedBody) else normalizedBody val newDecl = decl.copy(body = newBody) // Copy doesn't preserve .isRecursive! @@ -33,63 +35,104 @@ class ParameterNormalizer( newDecl } + // extract parameter types from the type tag + private def extractParamTypes(decl: TlaOperDecl): Seq[TlaType1] = { + decl.typeTag match { + case Typed(OperT1(paramTypes, _)) => + val typeParamCount = paramTypes.length + val sigParamCount = decl.formalParams.length + if (typeParamCount == sigParamCount) { + paramTypes + } else { + throw new TypingException( + "The signature of operator %s has %d parameters whereas the type tag has %d parameters" + .format(decl.name, sigParamCount, typeParamCount)) + } + + case _ => + throw new TypingException(s"Operator ${decl.name} has an invalid type tag: " + decl.typeTag) + } + } + /** Expression-level LET-IN transformation, applies `mkParamNormalForm` to all LET-IN declarations */ private def normalizeInternalLetIn: TlaExTransformation = tracker.trackEx { case ex @ LetInEx(body, defs @ _*) => val newDefs = defs map { d => normalizeDeclaration(d) } val newBody = normalizeInternalLetIn(body) if (defs == newDefs && body == newBody) ex - else LetInEx(newBody, newDefs: _*) + else LetInEx(newBody, newDefs: _*)(ex.typeTag) case ex @ OperEx(op, args @ _*) => val newArgs = args map normalizeInternalLetIn if (args == newArgs) ex - else OperEx(op, newArgs: _*) + else OperEx(op, newArgs: _*)(ex.typeTag) case ex => ex } /** Iteratively introduces a new operator for each formal parameter */ - private def normalizeParametersInEx(paramNames: List[FormalParam]): TlaExTransformation = tracker.trackEx { ex => - paramNames.foldLeft(ex) { case (partialEx, fParam) => - val paramOperName = nameGenerator.newName() - fParam match { - case SimpleFormalParam(name) => - // We replace all instances of `fParam` with `paramOperName` - // however, since paramOperName is an operator, we have to replace with application - val tr = ReplaceFixed(tracker)( - NameEx(fParam.name), - OperEx(TlaOper.apply, NameEx(paramOperName)) - ) - val replaced = tr(partialEx) - // if fParam is simple, the introduced operator is nullary - val letInDef = TlaOperDecl(paramOperName, List.empty, NameEx(name)) - LetInEx(replaced, letInDef) - - case OperFormalParam(name, arity) => - // We again replace all instances of `fParam` with `paramOperName` - // As both are operators, we don't need to introduce application - val tr = ReplaceFixed(tracker)( - NameEx(fParam.name), - NameEx(paramOperName) - ) - val replaced = tr(partialEx) - - // This time, the introduced operator is not nullary, so we need to invent parameters - val inventedParams = List.fill(arity) { - nameGenerator.newName() - } - // The body is just the operator applied to all the parameters - val newBody = OperEx( - TlaOper.apply, - name +: inventedParams map { NameEx(_) }: _* - ) - val letInDef = TlaOperDecl(paramOperName, inventedParams map SimpleFormalParam, newBody) - LetInEx(replaced, letInDef) + private def normalizeParametersInEx(paramNames: Seq[FormalParam], paramTypes: Seq[TlaType1]): TlaExTransformation = + tracker.trackEx { ex => + paramNames.zip(paramTypes).foldLeft(ex) { case (partialEx, (fParam, fParamType)) => + val paramOperName = nameGenerator.newName() + (fParam, fParamType) match { + case (SimpleFormalParam(name), paramType) => + // case 1: a normal parameter, not a higher-order one. + // We replace all instances of `fParam` with `paramOperName()` + // however, since paramOperName is an operator, we have to replace with application + val types = Map("t" -> OperT1(Seq(), paramType), "p" -> fParamType) + val tr = ReplaceFixed(tracker)( + tla.name(fParam.name).typed(types, "p"), + tla.appOp(tla.name(paramOperName) ? "t").typed(types, "p") + ) + val replaced = tr(partialEx) + // if fParam is simple, the introduced operator is nullary + val letInDef = + tla + .declOp(paramOperName, tla.name(name).typed(types, "p")) + .typedOperDecl(types, "t") + tla + .letIn(replaced, letInDef) + .typed(types, "p") + + case (OperFormalParam(name, arity), paramType) => + // case 2: a higher-order parameter. + // We again replace all instances of `fParam` with `paramOperName` + // As both are operators, we don't need to introduce application + val tr = ReplaceFixed(tracker)( + tla.name(fParam.name).typed(paramType), + tla.name(paramOperName).typed(paramType) + ) + val replaced = tr(partialEx) + + // This time, the introduced operator is not nullary, so we need to introduce fresh parameters + val freshParams = List.fill(arity) { + nameGenerator.newName() + } + + paramType match { + case OperT1(hoParamTypes, resType) => + // The body is just the operator applied to the fresh parameters + val freshParamsWithTypes = freshParams.zip(hoParamTypes).map { case (n, t) => tla.name(n).typed(t) } + val newBody = + tla + .appOp( + tla.name(name).typed(paramType), + freshParamsWithTypes: _* + ) + .typed(resType) + val letInDef = tla + .declOp(paramOperName, newBody, freshParams map SimpleFormalParam: _*) + .typedOperDecl(paramType) + tla.letIn(replaced, letInDef).typed(resType) + + case _ => + throw new TypingException(s"Expected a higher-order parameter $name, found type: $paramType") + } + } } } - } /** Module-level transformation, calls `mkParamNormalForm` on all operator declarations */ def normalizeModule: TlaModuleTransformation = { m => diff --git a/tla-pp/src/main/scala/at/forsyte/apalache/tla/pp/TlcConfigImporter.scala b/tla-pp/src/main/scala/at/forsyte/apalache/tla/pp/TlcConfigImporter.scala index a397ec1dc5..8a4b120e2f 100644 --- a/tla-pp/src/main/scala/at/forsyte/apalache/tla/pp/TlcConfigImporter.scala +++ b/tla-pp/src/main/scala/at/forsyte/apalache/tla/pp/TlcConfigImporter.scala @@ -1,10 +1,10 @@ package at.forsyte.apalache.tla.pp import at.forsyte.apalache.io.tlc.config._ +import at.forsyte.apalache.tla.lir.TypedPredefs._ import at.forsyte.apalache.tla.lir._ -import at.forsyte.apalache.tla.lir.oper.TlaOper +import at.forsyte.apalache.tla.lir.convenience._ import at.forsyte.apalache.tla.lir.transformations.{TlaModuleTransformation, TransformationTracker} -import at.forsyte.apalache.tla.lir.UntypedPredefs._ import com.typesafe.scalalogging.LazyLogging /** @@ -17,42 +17,81 @@ import com.typesafe.scalalogging.LazyLogging */ class TlcConfigImporter(config: TlcConfig, tracker: TransformationTracker) extends TlaModuleTransformation with LazyLogging { + private val boolOperT = OperT1(Seq(), BoolT1()) + + private def mkBoolName(name: String): TlaEx = { + tla.name(name).typed(BoolT1()) + } + override def apply(mod: TlaModule): TlaModule = { val assignments = config.constAssignments.map { case (param, value) => - TlaOperDecl(ConstAndDefRewriter.OVERRIDE_PREFIX + param, List(), value.toTlaEx) + val valueEx = value.toTlaEx + val operT = OperT1(Seq(), valueEx.typeTag.asTlaType1()) + tla.declOp(ConstAndDefRewriter.OVERRIDE_PREFIX + param, value.toTlaEx).typedOperDecl(operT) } - val operators = Set(mod.declarations.collect { case TlaOperDecl(name, _, _) => - name - }: _*) val replacements = config.constReplacements.map { case (param, value) => - if (operators.contains(value)) - TlaOperDecl(ConstAndDefRewriter.OVERRIDE_PREFIX + param, List(), OperEx(TlaOper.apply, NameEx(value))) - else - TlaOperDecl(ConstAndDefRewriter.OVERRIDE_PREFIX + param, List(), NameEx(value)) + mod.declarations.find(_.name == value) match { + case Some(d: TlaOperDecl) => + if (d.formalParams.isEmpty) { + val tt = d.typeTag.asTlaType1() + assert(tt.isInstanceOf[OperT1]) + val operT = tt.asInstanceOf[OperT1] + val application = tla.appOp(tla.name(value).typed(operT)).typed(operT.res) + tla.declOp(ConstAndDefRewriter.OVERRIDE_PREFIX + param, application).typedOperDecl(operT) + } else { + val nparams = d.formalParams.size + throw new TLCConfigurationError( + s"Met a replacement $param <- $value, where $value is an operator of $nparams parameters") + } + + case Some(d) => + // This is a branch from the old untyped encoding. Does it make sense in the type encoding? + val tt = d.typeTag.asTlaType1() + tla + .declOp(ConstAndDefRewriter.OVERRIDE_PREFIX + param, tla.name(value).typed(tt)) + .typedOperDecl(OperT1(Seq(), tt)) + + case None => + throw new TLCConfigurationError(s"Met a replacement $param <- $value, but $value is not found") + } } val stateConstraints = config.stateConstraints.zipWithIndex.map { case (value, index) => - TlaOperDecl(TlcConfigImporter.STATE_PREFIX + index, List(), NameEx(value)) + tla + .declOp(TlcConfigImporter.STATE_PREFIX + index, mkBoolName(value)) + .typedOperDecl(boolOperT) } val actionConstraints = config.actionConstraints.zipWithIndex.map { case (value, index) => - TlaOperDecl(TlcConfigImporter.ACTION_PREFIX + index, List(), NameEx(value)) + tla + .declOp(TlcConfigImporter.ACTION_PREFIX + index, mkBoolName(value)) + .typedOperDecl(boolOperT) } val invariants = config.invariants.zipWithIndex.map { case (value, index) => - TlaOperDecl(TlcConfigImporter.INVARIANT_PREFIX + index, List(), NameEx(value)) + tla + .declOp(TlcConfigImporter.INVARIANT_PREFIX + index, mkBoolName(value)) + .typedOperDecl(boolOperT) } val temporalProps = config.temporalProps.zipWithIndex.map { case (value, index) => - TlaOperDecl(TlcConfigImporter.TEMPORAL_PREFIX + index, List(), NameEx(value)) + tla + .declOp(TlcConfigImporter.TEMPORAL_PREFIX + index, mkBoolName(value)) + .typedOperDecl(boolOperT) } val behaviorSpec = config.behaviorSpec match { case InitNextSpec(init, next) => List( - TlaOperDecl(TlcConfigImporter.INIT, List(), NameEx(init)), - TlaOperDecl(TlcConfigImporter.NEXT, List(), NameEx(next)) + tla + .declOp(TlcConfigImporter.INIT, mkBoolName(init)) + .typedOperDecl(boolOperT), + tla + .declOp(TlcConfigImporter.NEXT, mkBoolName(next)) + .typedOperDecl(boolOperT) ) case TemporalSpec(name) => List( - TlaOperDecl(TlcConfigImporter.SPEC, List(), NameEx(name)) + tla + .declOp(TlcConfigImporter.SPEC, mkBoolName(name)) + .typedOperDecl(boolOperT) ) case NullSpec() => diff --git a/tla-pp/src/main/scala/at/forsyte/apalache/tla/pp/Unroller.scala b/tla-pp/src/main/scala/at/forsyte/apalache/tla/pp/Unroller.scala index 7aac251d90..10a2aae474 100644 --- a/tla-pp/src/main/scala/at/forsyte/apalache/tla/pp/Unroller.scala +++ b/tla-pp/src/main/scala/at/forsyte/apalache/tla/pp/Unroller.scala @@ -1,6 +1,6 @@ package at.forsyte.apalache.tla.pp -import at.forsyte.apalache.tla.lir._ +import at.forsyte.apalache.tla.lir.{LetInEx, NameEx, OperEx, TlaDecl, TlaEx, TlaModule, TlaOperDecl, ValEx} import at.forsyte.apalache.tla.lir.aux.{ExceptionOrValue, FailWith, SucceedWith} import at.forsyte.apalache.tla.lir.oper.TlaOper import at.forsyte.apalache.tla.lir.storage.{BodyMap, BodyMapFactory} @@ -47,7 +47,6 @@ class Unroller(nameGenerator: UniqueNameGenerator, tracker: TransformationTracke extends TlaModuleTransformation { import Unroller._ - import at.forsyte.apalache.tla.lir.UntypedPredefs._ // unrollLetIn performs unrolling on all recursive LET-IN defined operators in the expression private def unrollLetIn( @@ -69,10 +68,11 @@ class Unroller(nameGenerator: UniqueNameGenerator, tracker: TransformationTracke if (defs == newDefs && body == newBody) ex else - LetInEx(newBody, newDefs: _*) + LetInEx(newBody, newDefs: _*)(ex.typeTag) + case ex @ OperEx(op, args @ _*) => val newArgs = args map unrollLetIn(bodyMap) - if (args == newArgs) ex else OperEx(op, newArgs: _*) + if (args == newArgs) ex else OperEx(op, newArgs: _*)(ex.typeTag) case ex => ex } @@ -94,11 +94,11 @@ class Unroller(nameGenerator: UniqueNameGenerator, tracker: TransformationTracke } val newBody = transform(body) if (defs == newDefs && body == newBody) ex - else LetInEx(newBody, newDefs: _*) + else LetInEx(newBody, newDefs: _*)(ex.typeTag) case ex @ OperEx(op, args @ _*) => val newArgs = args map replaceWithDefaults(defaultsMap) - if (args == newArgs) ex else OperEx(op, newArgs: _*) + if (args == newArgs) ex else OperEx(op, newArgs: _*)(ex.typeTag) case ex => ex } @@ -171,7 +171,8 @@ class Unroller(nameGenerator: UniqueNameGenerator, tracker: TransformationTracke // Any remaining applications are default-replaced val defaultReplaced = replaceWithDefaultsTr(inlined) // must specifically set .isRecursive to false, so no .copy - TlaOperDecl(name, fparams, defaultReplaced) + TlaOperDecl(name, fparams, defaultReplaced)(d.typeTag) + case d @ TlaOperDecl(_, _, body) => // d.isRecursive = false // Even though the operator is not recursive, it still may define recursive LET-IN operators inside val unrolledLetIn = unrollLetIn(bodyMap)(body) @@ -186,10 +187,12 @@ class Unroller(nameGenerator: UniqueNameGenerator, tracker: TransformationTracke private def allRecursiveLetInOperatorNames(ex: TlaEx): Set[String] = ex match { case OperEx(_, args @ _*) => args.map(allRecursiveLetInOperatorNames).foldLeft(Set.empty[String]) { _ ++ _ } + case LetInEx(body, defs @ _*) => val allInDefs = defs.map(allRecursiveOperatorsInDecl).foldLeft(Set.empty[String]) { _ ++ _ } val allInBody = allRecursiveLetInOperatorNames(body) allInBody ++ allInDefs + case _ => Set.empty[String] } diff --git a/tla-pp/src/main/scala/at/forsyte/apalache/tla/pp/passes/DesugarerPassImpl.scala b/tla-pp/src/main/scala/at/forsyte/apalache/tla/pp/passes/DesugarerPassImpl.scala index ec90f4650d..8cb0fbb2fe 100644 --- a/tla-pp/src/main/scala/at/forsyte/apalache/tla/pp/passes/DesugarerPassImpl.scala +++ b/tla-pp/src/main/scala/at/forsyte/apalache/tla/pp/passes/DesugarerPassImpl.scala @@ -2,10 +2,10 @@ package at.forsyte.apalache.tla.pp.passes import at.forsyte.apalache.infra.passes.{Pass, PassOptions, TlaModuleMixin} import at.forsyte.apalache.tla.lir.TlaModule -import at.forsyte.apalache.tla.lir.io.{PrettyWriter, TlaWriterFactory} +import at.forsyte.apalache.tla.lir.io.TlaWriterFactory import at.forsyte.apalache.tla.lir.transformations.TransformationTracker import at.forsyte.apalache.tla.lir.transformations.standard._ -import at.forsyte.apalache.tla.pp.Desugarer +import at.forsyte.apalache.tla.pp.{Desugarer, UniqueNameGenerator} import com.google.inject.Inject import com.google.inject.name.Named import com.typesafe.scalalogging.LazyLogging @@ -21,7 +21,7 @@ import java.nio.file.Path * @param nextPass next pass to call */ class DesugarerPassImpl @Inject() ( - val options: PassOptions, tracker: TransformationTracker, writerFactory: TlaWriterFactory, + val options: PassOptions, tracker: TransformationTracker, gen: UniqueNameGenerator, writerFactory: TlaWriterFactory, @Named("AfterDesugarer") nextPass: Pass with TlaModuleMixin ) extends DesugarerPass with LazyLogging { @@ -42,7 +42,7 @@ class DesugarerPassImpl @Inject() ( override def execute(): Boolean = { logger.info(" > Desugaring...") val input = tlaModule.get - val output = ModuleByExTransformer(Desugarer(tracker))(input) + val output = ModuleByExTransformer(Desugarer(gen, tracker))(input) // dump the result of preprocessing val outdir = options.getOrError("io", "outdir").asInstanceOf[Path] diff --git a/tla-pp/src/main/scala/at/forsyte/apalache/tla/pp/passes/PreproPassImpl.scala b/tla-pp/src/main/scala/at/forsyte/apalache/tla/pp/passes/PreproPassImpl.scala index ef217a812d..029c580354 100644 --- a/tla-pp/src/main/scala/at/forsyte/apalache/tla/pp/passes/PreproPassImpl.scala +++ b/tla-pp/src/main/scala/at/forsyte/apalache/tla/pp/passes/PreproPassImpl.scala @@ -51,7 +51,7 @@ class PreproPassImpl @Inject() ( val transformationSequence: List[(String, TlaModuleTransformation)] = List( ("PrimePropagation", createModuleTransformerForPrimePropagation(varSet)), - ("Desugarer", ModuleByExTransformer(Desugarer(tracker))), + ("Desugarer", ModuleByExTransformer(Desugarer(gen, tracker))), ("UniqueRenamer", renaming.renameInModule), ("Normalizer", ModuleByExTransformer(Normalizer(tracker))), ("Keramelizer", ModuleByExTransformer(Keramelizer(gen, tracker))) diff --git a/tla-pp/src/main/scala/at/forsyte/apalache/tla/pp/passes/UnrollPassImpl.scala b/tla-pp/src/main/scala/at/forsyte/apalache/tla/pp/passes/UnrollPassImpl.scala index 06e46f2763..3a333ee13c 100644 --- a/tla-pp/src/main/scala/at/forsyte/apalache/tla/pp/passes/UnrollPassImpl.scala +++ b/tla-pp/src/main/scala/at/forsyte/apalache/tla/pp/passes/UnrollPassImpl.scala @@ -7,7 +7,6 @@ import at.forsyte.apalache.tla.lir.TlaModule import at.forsyte.apalache.tla.lir.io.{PrettyWriter, TlaWriterFactory} import at.forsyte.apalache.tla.lir.transformations.TransformationTracker import at.forsyte.apalache.tla.lir.transformations.standard.IncrementalRenaming -import at.forsyte.apalache.tla.lir.UntypedPredefs._ import at.forsyte.apalache.tla.pp.{UniqueNameGenerator, Unroller} import com.google.inject.Inject import com.google.inject.name.Named diff --git a/tla-pp/src/test/scala/at/forsyte/apalache/tla/pp/TestConstAndDefRewriter.scala b/tla-pp/src/test/scala/at/forsyte/apalache/tla/pp/TestConstAndDefRewriter.scala index 1708f699b7..d5c1e2658d 100644 --- a/tla-pp/src/test/scala/at/forsyte/apalache/tla/pp/TestConstAndDefRewriter.scala +++ b/tla-pp/src/test/scala/at/forsyte/apalache/tla/pp/TestConstAndDefRewriter.scala @@ -5,7 +5,7 @@ import at.forsyte.apalache.tla.imp.SanyImporter import at.forsyte.apalache.tla.imp.src.SourceStore import at.forsyte.apalache.tla.lir.convenience._ import at.forsyte.apalache.tla.lir.transformations.impl.IdleTracker -import at.forsyte.apalache.tla.lir.{SimpleFormalParam, TlaOperDecl} +import at.forsyte.apalache.tla.lir.{IntT1, OperT1, SimpleFormalParam, TlaModule, TlaOperDecl, Typed} import at.forsyte.apalache.tla.lir.UntypedPredefs._ import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner @@ -37,7 +37,19 @@ class TestConstAndDefRewriter extends FunSuite with BeforeAndAfterEach { val (rootName, modules) = sanyImporter.loadFromSource("const", Source.fromString(text)) val root = modules(rootName) - val rewritten = new ConstAndDefRewriter(new IdleTracker())(root) + // we don't want to run a type checker, so we just hack the type of the declaration n + val newDeclarations = + root.declarations match { + case Seq(n, overrideN: TlaOperDecl, rest @ _*) => + val typedN = n.withTag(Typed(IntT1())) + val overrideTag = Typed(OperT1(Seq(), IntT1())) + val typedOverrideN = TlaOperDecl(overrideN.name, List(), overrideN.body.withTag(Typed(IntT1())))(overrideTag) + Seq(typedN, typedOverrideN) ++ rest + } + + val input = new TlaModule(root.name, newDeclarations) + + val rewritten = new ConstAndDefRewriter(new IdleTracker())(input) assert(rewritten.constDeclarations.isEmpty) // no constants anymore assert(rewritten.operDeclarations.size == 2) val expected_n = TlaOperDecl("n", List(), tla.int(10)) diff --git a/tla-pp/src/test/scala/at/forsyte/apalache/tla/pp/TestDesugarer.scala b/tla-pp/src/test/scala/at/forsyte/apalache/tla/pp/TestDesugarer.scala index 4a21b75129..d6ffac86ad 100644 --- a/tla-pp/src/test/scala/at/forsyte/apalache/tla/pp/TestDesugarer.scala +++ b/tla-pp/src/test/scala/at/forsyte/apalache/tla/pp/TestDesugarer.scala @@ -1,322 +1,723 @@ package at.forsyte.apalache.tla.pp -import at.forsyte.apalache.tla.lir.{SimpleFormalParam, TlaEx} +import at.forsyte.apalache.tla.lir.TypedPredefs._ import at.forsyte.apalache.tla.lir.convenience._ -import at.forsyte.apalache.tla.lir.UntypedPredefs._ -import at.forsyte.apalache.tla.lir.transformations.impl.TrackerWithListeners +import at.forsyte.apalache.tla.lir.transformations.impl.IdleTracker +import at.forsyte.apalache.tla.lir.{ + BoolT1, FunT1, IntT1, OperT1, RecT1, SeqT1, SetT1, SimpleFormalParam, StrT1, TlaEx, TlaType1, TupT1 +} import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner import org.scalatest.{BeforeAndAfterEach, FunSuite} +import scala.collection.immutable.SortedMap + @RunWith(classOf[JUnitRunner]) class TestDesugarer extends FunSuite with BeforeAndAfterEach { - private var desugarer = new Desugarer(TrackerWithListeners()) + private var gen: UniqueNameGenerator = _ + private var desugarer: Desugarer = _ + private val exceptTypes = Map( + "f1" -> FunT1(IntT1(), IntT1()), + "f2" -> FunT1(IntT1(), FunT1(IntT1(), IntT1())), + "f3" -> FunT1(IntT1(), FunT1(IntT1(), FunT1(IntT1(), IntT1()))), + "r" -> RecT1(SortedMap("foo" -> StrT1(), "bar" -> IntT1())), + "or" -> OperT1(Seq(), RecT1(SortedMap("foo" -> StrT1(), "bar" -> IntT1()))), + "fr" -> FunT1(IntT1(), RecT1(SortedMap("foo" -> StrT1(), "bar" -> IntT1()))), + "ofr" -> OperT1(Seq(), FunT1(IntT1(), RecT1(SortedMap("foo" -> StrT1(), "bar" -> IntT1())))), + "ii" -> TupT1(IntT1(), IntT1()), + "si" -> TupT1(StrT1(), IntT1()), + "is" -> TupT1(IntT1(), StrT1()), + "f_si" -> FunT1(IntT1(), TupT1(StrT1(), IntT1())), + "o_si" -> OperT1(Seq(), TupT1(IntT1(), StrT1())), + "o_f_si" -> OperT1(Seq(), FunT1(IntT1(), TupT1(StrT1(), IntT1()))), + "i_to_Qs" -> FunT1(IntT1(), SeqT1(StrT1())), + "o_i_to_Qs" -> OperT1(Seq(), FunT1(IntT1(), SeqT1(StrT1()))), + "Qs" -> SeqT1(StrT1()), + "o_Qs" -> OperT1(Seq(), SeqT1(StrT1())), + "O1" -> OperT1(Seq(), FunT1(IntT1(), IntT1())), + "O2" -> OperT1(Seq(), FunT1(IntT1(), FunT1(IntT1(), IntT1()))), + "O3" -> OperT1(Seq(), FunT1(IntT1(), FunT1(IntT1(), FunT1(IntT1(), IntT1())))), + "i" -> IntT1(), + "i1" -> TupT1(IntT1()), + "i2" -> TupT1(IntT1(), IntT1()), + "i3" -> TupT1(IntT1(), IntT1(), IntT1()), + "s" -> StrT1(), + "s1" -> TupT1(StrT1()) + ) + private val unchangedTypes = Map( + "i" -> IntT1(), + "b" -> BoolT1(), + "et" -> TupT1(), + "b1" -> TupT1(BoolT1()), + "i_b1_2" -> TupT1(IntT1(), TupT1(BoolT1())), + "ib_2" -> TupT1(IntT1(), BoolT1()), + "ibi_3" -> TupT1(IntT1(), BoolT1()), + "i_ib_2_2" -> TupT1(IntT1(), TupT1(IntT1(), BoolT1())), + "f1" -> FunT1(IntT1(), IntT1()) + ) + private val tupleTypes = Map( + "b" -> BoolT1(), + "i" -> IntT1(), + "I" -> SetT1(IntT1()), + "i_to_I" -> FunT1(IntT1(), SetT1(IntT1())), + "ii_to_i" -> FunT1(TupT1(IntT1(), IntT1()), IntT1()), + "ii_2" -> TupT1(IntT1(), IntT1()), + "i_ii_2_2" -> TupT1(IntT1(), TupT1(IntT1(), IntT1())), + "I_II_2_2" -> SetT1(TupT1(IntT1(), TupT1(IntT1(), IntT1()))), + "II_2" -> SetT1(TupT1(IntT1(), IntT1())), + "i_ii_2_2_to_ii_2" -> FunT1(TupT1(IntT1(), TupT1(IntT1(), IntT1())), TupT1(IntT1(), IntT1())) + ) override def beforeEach(): Unit = { - desugarer = new Desugarer(TrackerWithListeners()) + gen = new UniqueNameGenerator() + desugarer = new Desugarer(gen, new IdleTracker()) + } + + // call the operator that returns a function of type stored in exceptTypes(funAlias) and access it with indices + private def callAndAccess(operName: String, funAlias: String, indices: String*): TlaEx = { + def eatFun(tt: TlaType1, key: String): (TlaType1, TlaType1) = { + tt match { + case FunT1(arg, res) => + (arg, res) + + case RecT1(fieldTypes) => + if (fieldTypes.contains(key)) { + (StrT1(), fieldTypes(key)) + } else { + throw new IllegalArgumentException(s"No key $key in $tt") + } + + case TupT1(elems @ _*) => + val intKey = key.toInt + if (intKey > 0 && intKey <= elems.length) { + (IntT1(), elems(intKey)) + } else { + throw new IllegalArgumentException(s"No index $key in $tt") + } + } + } + + val tt = exceptTypes(funAlias) + val operT = OperT1(Seq(), tt) + indices.foldLeft(tla.appOp(tla.name(operName).typed(operT)).typed(tt)) { case (a, n) => + val (argT, resT) = eatFun(a.typeTag.asTlaType1(), n) + tla.appFun(a, tla.name(n).typed(argT)).typed(resT) + } + } + + test("EXCEPT one-dimensional, one index") { + // input: [f EXCEPT ![<>] = e] + val input = + tla + .except(tla.name("f") ? "f1", tla.tuple(tla.name("i") ? "i") ? "i1", tla.name("e") ? "i") + .typed(exceptTypes, "f1") + // output: the same as input + val output = desugarer.transform(input) + assert(output eqTyped input) + } + + test("EXCEPT two-dimensional, one index") { + // input: [f EXCEPT ![i][j] = e] + val input = + tla + .except(tla.name("f") ? "f2", tla.tuple(tla.name("i") ? "i", tla.name("j") ? "i") ? "i2", tla.name("e") ? "i") + .typed(exceptTypes, "f2") + val output = desugarer.transform(input) + // output: series of LET-IN definitions + // LET t_1 == f + // t_2 == [t_1()[i] EXCEPT ![<>] = e] + // t_3 == [t_1() EXCEPT ![<>] = t_2()] + // IN t_3() + val defs = Seq( + tla + .declOp("t_1", tla.name("f") ? "f2") + .typedOperDecl(exceptTypes, "O2"), + tla + .declOp("t_2", + tla.except(callAndAccess("t_1", "f2", "i"), tla.tuple(tla.name("j") ? "i") ? "i1", + tla.name("e") ? "i") ? "f1") + .typedOperDecl(exceptTypes, "O1"), + tla + .declOp("t_3", + tla.except(callAndAccess("t_1", "f1"), tla.tuple(tla.name("i") ? "i") ? "i1", + callAndAccess("t_2", "f1")) ? "f2") + .typedOperDecl(exceptTypes, "O2") + ) + + val expected: TlaEx = + tla + .letIn(callAndAccess("t_3", "f2"), defs: _*) + .typed(exceptTypes, "f2") + + assert(expected eqTyped output) } - test("chain 2 excepts") { + test("EXCEPT two-dimensional, function + record") { + // input: [f EXCEPT ![i].foo = e] + val input = + tla + .except(tla.name("f") ? "fr", tla.tuple(tla.name("i") ? "i", tla.name("foo") ? "s") ? "is", tla.name("e") ? "s") + .typed(exceptTypes, "fr") + val output = desugarer.transform(input) + // output: series of LET-IN definitions + // LET t_1 == f + // t_2 == [t_1()[i] EXCEPT ![<<"foo">>] = e] + // t_3 == [t_1() EXCEPT ![<>] = t_2()] + // IN t_3() + val defs = Seq( + tla + .declOp("t_1", tla.name("f") ? "fr") + .typedOperDecl(exceptTypes, "ofr"), + tla + .declOp("t_2", + tla.except(callAndAccess("t_1", "f2", "i"), tla.tuple(tla.name("foo") ? "s") ? "s1", + tla.name("e") ? "s") ? "r") + .typedOperDecl(exceptTypes, "or"), + tla + .declOp("t_3", + tla.except(callAndAccess("t_1", "r"), tla.tuple(tla.name("i") ? "i") ? "i1", + callAndAccess("t_2", "f1")) ? "fr") + .typedOperDecl(exceptTypes, "ofr") + ) + + val expected: TlaEx = + tla + .letIn(callAndAccess("t_3", "fr"), defs: _*) + .typed(exceptTypes, "fr") + + assert(expected eqTyped output) + } + + test("EXCEPT two-dimensional, function + tuple") { + // input: [f EXCEPT ![i][1] = e] + val input = + tla + .except(tla.name("f") ? "f_si", tla.tuple(tla.name("i") ? "i", tla.int(1)) ? "ii", tla.name("e") ? "s") + .typed(exceptTypes, "f_si") + val output = desugarer.transform(input) + // output: series of LET-IN definitions + // LET t_1 == f + // t_2 == [t_1()[i] EXCEPT ![<<1>>] = e] + // t_3 == [t_1() EXCEPT ![<>] = t_2()] + // IN t_3() + val defs = Seq( + tla + .declOp("t_1", tla.name("f") ? "f_si") + .typedOperDecl(exceptTypes, "o_f_si"), + tla + .declOp("t_2", + tla.except(callAndAccess("t_1", "f_si", "i"), tla.tuple(tla.int(1)) ? "i1", tla.name("e") ? "s") ? "si") + .typedOperDecl(exceptTypes, "o_si"), + tla + .declOp("t_3", + tla.except(callAndAccess("t_1", "si"), tla.tuple(tla.name("i") ? "i") ? "i1", + callAndAccess("t_2", "si")) ? "f_si") + .typedOperDecl(exceptTypes, "o_f_si") + ) + + val expected: TlaEx = + tla + .letIn(callAndAccess("t_3", "f_si"), defs: _*) + .typed(exceptTypes, "f_si") + + assert(expected eqTyped output) + } + + test("EXCEPT two-dimensional, function + sequence") { // input: [f EXCEPT ![i][j] = e] - val highCalories = - tla.except(tla.name("f"), tla.tuple(tla.name("i"), tla.name("j")), tla.name("e")) - val sugarFree = desugarer.transform(highCalories) - // output [ f EXCEPT ![i] = [f[i] EXCEPT ![j] = e] ] + val input = + tla + .except(tla.name("f") ? "i_to_Qs", tla.tuple(tla.name("i") ? "i", tla.name("j") ? "i") ? "ii", + tla.name("e") ? "s") + .typed(exceptTypes, "i_to_Qs") + val output = desugarer.transform(input) + // output: series of LET-IN definitions + // LET t_1 == f + // t_2 == [t_1()[i] EXCEPT ![<>] = e] + // t_3 == [t_1() EXCEPT ![<>] = t_2()] + // IN t_3() + val defs = Seq( + tla + .declOp("t_1", tla.name("f") ? "i_to_Qs") + .typedOperDecl(exceptTypes, "o_i_to_Qs"), + tla + .declOp("t_2", + tla.except(callAndAccess("t_1", "i_to_Qs", "i"), tla.tuple(tla.name("j") ? "s") ? "i1", + tla.name("e") ? "s") ? "Qs") + .typedOperDecl(exceptTypes, "o_Qs"), + tla + .declOp("t_3", + tla.except(callAndAccess("t_1", "Qs"), tla.tuple(tla.name("i") ? "i") ? "i1", + callAndAccess("t_2", "si")) ? "i_to_Qs") + .typedOperDecl(exceptTypes, "o_i_to_Qs") + ) + val expected: TlaEx = - tla.except(tla.name("f"), tla.tuple(tla.name("i")), - tla.except(tla.appFun(tla.name("f"), tla.name("i")), tla.tuple(tla.name("j")), tla.name("e"))) + tla + .letIn(callAndAccess("t_3", "f_si"), defs: _*) + .typed(exceptTypes, "i_to_Qs") - assert(expected == sugarFree) + assert(expected eqTyped output) } - test("chain 3 excepts") { + test("EXCEPT three-dimensional, one index") { // input: [f EXCEPT ![i][j][k] = e] - val highCalories = - tla.except(tla.name("f"), tla.tuple(tla.name("i"), tla.name("j"), tla.name("k")), tla.name("e")) - val sugarFree = desugarer.transform(highCalories) - // output: [ f EXCEPT ![i] = [f[i] EXCEPT ![j] = [f[i][j] EXCEPT ![k] = e] ] ] - val expected: TlaEx = tla.except(tla.name("f"), tla.tuple(tla.name("i")), - tla.except(tla.appFun(tla.name("f"), tla.name("i")), tla.tuple(tla.name("j")), - tla.except(tla.appFun(tla.appFun(tla.name("f"), tla.name("i")), tla.name("j")), tla.tuple(tla.name("k")), - tla.name("e")))) - - assert(expected == sugarFree) + val input = + tla + .except(tla.name("f") ? "f3", tla.tuple(tla.name("i") ? "i", tla.name("j") ? "i", tla.name("k") ? "i") ? "i3", + tla.name("e") ? "i") + .typed(exceptTypes, "f3") + val output = desugarer.transform(input) + // output: series of LET-IN definitions + // LET t_1 == f + // t_2 == [t_1()[i][j] EXCEPT ![<>] = e] + // t_3 == [t_1()[i] EXCEPT ![<>] = t_2()] + // t_4 == [t_1() EXCEPT ![<>] = t_3()] + // IN t_4() + val defs = Seq( + tla + .declOp("t_1", tla.name("f") ? "f3") + .typedOperDecl(exceptTypes, "O3"), + tla + .declOp("t_2", + tla.except(callAndAccess("t_1", "f3", "i", "j"), tla.tuple(tla.name("k") ? "i") ? "i1", + tla.name("e") ? "i") ? "f1") + .typedOperDecl(exceptTypes, "O1"), + tla + .declOp("t_3", + tla.except(callAndAccess("t_1", "f3", "i"), tla.tuple(tla.name("j") ? "i") ? "i1", + callAndAccess("t_2", "f1")) ? "f2") + .typedOperDecl(exceptTypes, "O2"), + tla + .declOp("t_4", + tla.except(callAndAccess("t_1", "f1"), tla.tuple(tla.name("i") ? "i") ? "i1", + callAndAccess("t_3", "f3")) ? "f3") + .typedOperDecl(exceptTypes, "O3") + ) + + val expected: TlaEx = + tla + .letIn(callAndAccess("t_4", "f3"), defs: _*) + .typed(exceptTypes, "f3") + + assert(expected eqTyped output) + } + + test("EXCEPT with two updates") { + // input: [f EXCEPT ![i][j] = e1, ![k][l] = e2] + val input = + tla + .except(tla.name("f") ? "f2", tla.tuple(tla.name("i") ? "i", tla.name("j") ? "i") ? "i2", tla.name("e1") ? "i", + tla.tuple(tla.name("k") ? "i", tla.name("l") ? "i") ? "i2", tla.name("e2") ? "i") + .typed(exceptTypes, "f2") + val output = desugarer.transform(input) + // output: a series of definitions + // LET t_1 == f + // t_2 == [t_1()[i] EXCEPT ![<>] = e1] + // t_3 == [t_1() EXCEPT ![<>] = t_2()] + // t_4 == [t_3()[k] EXCEPT ![<>] = e2] + // t_5 == [t_3() EXCEPT ![<>] = t_4()] + // IN t_5() + val defs = Seq( + tla + .declOp("t_1", tla.name("f") ? "f2") + .typedOperDecl(exceptTypes, "O2"), + tla + .declOp("t_2", + tla.except(callAndAccess("t_1", "f2", "i"), tla.tuple(tla.name("j") ? "i") ? "i1", + tla.name("e1") ? "i") ? "f1") + .typedOperDecl(exceptTypes, "O1"), + tla.declOp("t_3", + tla.except(callAndAccess("t_1", "f2"), tla.tuple(tla.name("i") ? "i") ? "i1", + callAndAccess("t_2", "f1")) ? "f2") + typedOperDecl (exceptTypes, "O2"), + tla + .declOp("t_4", + tla.except(callAndAccess("t_3", "f2", "k"), tla.tuple(tla.name("l") ? "i") ? "i1", + tla.name("e2") ? "i") ? "f1") + .typedOperDecl(exceptTypes, "O1"), + tla + .declOp("t_5", + tla.except(callAndAccess("t_3", "f2"), tla.tuple(tla.name("k") ? "i") ? "i1", + callAndAccess("t_4", "f1")) ? "f2") + .typedOperDecl(exceptTypes, "O2") + ) + + val expected: TlaEx = + tla + .letIn(callAndAccess("t_5", "f2"), defs: _*) + .typed(exceptTypes, "f2") + + assert(expected eqTyped output) } test("""rewrite UNCHANGED x to x' = x""") { // input: x - val unchanged = tla.unchanged(tla.name("x")) - val sugarFree = desugarer.transform(unchanged) + val input = tla + .unchanged(tla.name("x") ? "i") + .typed(unchangedTypes, "b") + val output = desugarer.transform(input) // output: x' = x - val expected: TlaEx = tla.eql(tla.prime(tla.name("x")), tla.name("x")) - assert(expected == sugarFree) + val expected = + tla + .eql(tla.prime(tla.name("x") ? "i") ? "i", tla.name("x") ? "i") + .typed(unchangedTypes, "b") + assert(expected eqTyped output) } - test("""rewrite UNCHANGED <> to x' = x /\ y' = y""") { + test("""rewrite UNCHANGED <> >> to x' = x /\ y' = y""") { // input: <> >> - val unchanged = tla.unchangedTup(tla.name("x"), tla.tuple(tla.name("y"))) - val sugarFree = desugarer.transform(unchanged) + val input = + tla + .unchanged(tla.tuple(tla.name("x") ? "i", tla.tuple(tla.name("y") ? "b") ? "b1") ? "i_b1_2") + .typed(unchangedTypes, "b") + val output = desugarer.transform(input) // output: x' = x /\ y' = y val expected: TlaEx = - tla.and( - tla.eql(tla.prime(tla.name("x")), tla.name("x")), - tla.eql(tla.prime(tla.name("y")), tla.name("y")) - ) /// - assert(expected == sugarFree) + tla + .and( + tla.eql(tla.prime(tla.name("x") ? "i") ? "i", tla.name("x") ? "i") ? "b", + tla.eql(tla.prime(tla.name("y") ? "b") ? "b", tla.name("y") ? "b") ? "b" + ) + .typed(unchangedTypes, "b") + assert(expected eqTyped output) } test("""rewrite <> = <> to x = a /\ y = b""") { // This pattern looks like a parallel assignment. It stems from preprocessing of UNCHANGED and prime. // input: <> = <> - val parallel = - tla.eql(tla.tuple(tla.name("x"), tla.name("y")), tla.tuple(tla.name("a"), tla.name("b"))) + val input = + tla + .eql(tla.tuple(tla.name("x") ? "i", tla.name("y") ? "b") ? "ib_2", + tla.tuple(tla.name("a") ? "i", tla.name("b") ? "b") ? "ib_2") + .typed(unchangedTypes, "b") - val sugarFree = desugarer.transform(parallel) + val output = desugarer.transform(input) // output: x = a /\ y = b val expected: TlaEx = - tla.and( - tla.eql(tla.name("x"), tla.name("a")), - tla.eql(tla.name("y"), tla.name("b")) - ) /// - assert(expected == sugarFree) + tla + .and( + tla.eql(tla.name("x") ? "i", tla.name("a") ? "b") ? "b", + tla.eql(tla.name("y") ? "i", tla.name("b") ? "b") ? "b" + ) + .typed(unchangedTypes, "b") + assert(expected eqTyped output) } test("""rewrite <> /= <> to x /= a \/ y /= b""") { - val left = tla.tuple(tla.name("x"), tla.name("y")) - val right = tla.tuple(tla.name("a"), tla.name("b")) + val left = tla.tuple(tla.name("x") ? "i", tla.name("y") ? "b") ? "ib_2" + val right = tla.tuple(tla.name("a") ? "i", tla.name("b") ? "b") ? "ib_2" // input: <> /= <> - val parallel = tla.neql(left, right) + val parallel = tla.neql(left, right).typed(unchangedTypes, "b") - val sugarFree = desugarer.transform(parallel) + val output = desugarer.transform(parallel) // output: x /= a \/ y /= b - val expected: TlaEx = - tla.or( - tla.neql(tla.name("x"), tla.name("a")), - tla.neql(tla.name("y"), tla.name("b")) - ) /// - assert(expected == sugarFree) + val expected = + tla + .or( + tla.neql(tla.name("x") ? "i", tla.name("a") ? "b") ? "b", + tla.neql(tla.name("y") ? "i", tla.name("b") ? "b") ? "b" + ) + .typed(unchangedTypes, "b") + assert(expected eqTyped output) } test("""rewrite <> = <> to FALSE""") { - val left = tla.tuple(tla.name("x"), tla.name("y")) - val right = tla.tuple(tla.name("a"), tla.name("b"), tla.name("c")) + val left = tla.tuple(tla.name("x") ? "i", tla.name("y") ? "b") ? "ib_2" + val right = tla.tuple(tla.name("a") ? "i", tla.name("b") ? "b", tla.name("c") ? "i") ? "ibi_3" // input: <> = <> - val parallel = tla.eql(left, right) + val input = tla.eql(left, right).typed(unchangedTypes, "b") - val sugarFree = desugarer.transform(parallel) + val output = desugarer.transform(input) // output: FALSE - val expected: TlaEx = tla.bool(false) - assert(expected == sugarFree) + val expected: TlaEx = tla.bool(false).typed() + assert(expected eqTyped output) } test("""rewrite <> /= <> to TRUE""") { - val left = tla.tuple(tla.name("x"), tla.name("y")) - val right = tla.tuple(tla.name("a"), tla.name("b"), tla.name("c")) + val left = tla.tuple(tla.name("x") ? "i", tla.name("y") ? "b") ? "ib_2" + val right = tla.tuple(tla.name("a") ? "i", tla.name("b") ? "b", tla.name("c") ? "i") ? "ibi_3" // input: <> /= <> - val parallel = tla.neql(left, right) + val parallel = tla.neql(left, right).typed(unchangedTypes, "b") - val sugarFree = desugarer.transform(parallel) + val output = desugarer.transform(parallel) // output: TRUE - val expected: TlaEx = tla.bool(true) - assert(expected == sugarFree) + val expected: TlaEx = tla.bool(true).typed() + assert(expected eqTyped output) } test("unfold UNCHANGED <> >> to UNCHANGED <>") { // This is an idiom that was probably introduced by Diego Ongaro in Raft. // There is no added value in this construct, so it is just sugar. // We do the transformation right here. - val unchanged = tla.unchangedTup(tla.name("x"), tla.tuple(tla.name("y"), tla.name("z"))) + val unchanged = + tla + .unchanged(tla.tuple(tla.name("x") ? "i", + tla.tuple(tla.name("y") ? "i", tla.name("z") ? "b") ? "ib_2") ? "i_ib_2_2") + .typed(unchangedTypes, "b") val sugarFree = desugarer.transform(unchanged) val expected: TlaEx = - tla.and( - tla.eql(tla.prime(tla.name("x")), tla.name("x")), - tla.eql(tla.prime(tla.name("y")), tla.name("y")), - tla.eql(tla.prime(tla.name("z")), tla.name("z")) - ) /// - assert(expected == sugarFree) + tla + .and( + tla.eql(tla.prime(tla.name("x") ? "i") ? "i", tla.name("x") ? "i") ? "b", + tla.eql(tla.prime(tla.name("y") ? "i") ? "i", tla.name("y") ? "i") ? "b", + tla.eql(tla.prime(tla.name("z") ? "i") ? "i", tla.name("z") ? "i") ? "b" + ) + .typed(unchangedTypes, "b") + assert(expected eqTyped sugarFree) } test("""rewrite UNCHANGED <<>> to TRUE""") { // this is a regression for issue #375 // input: << >> - val unchanged = tla.unchangedTup() - val sugarFree = desugarer.transform(unchanged) + val input = tla.unchanged(tla.tuple() ? "et").typed(unchangedTypes, "b") + val output = desugarer.transform(input) // output: TRUE - val expected: TlaEx = tla.bool(true) - assert(expected == sugarFree) + val expected: TlaEx = tla.bool(true).typed() + assert(expected eqTyped output) } test("""rewrite UNCHANGED << <<>>, <<>> >> to TRUE""") { // this is a regression for issue #375 // input: << <<>>, <<>> >> - val unchanged = tla.unchangedTup(tla.unchangedTup(), tla.unchangedTup()) - val sugarFree = desugarer.transform(unchanged) + val input = tla + .unchanged(tla.tuple(tla.tuple() ? "et", tla.tuple() ? "et") ? "et_2") + .typed(unchangedTypes + ("et_2" -> TupT1(TupT1(), TupT1())), "b") + val output = desugarer.transform(input) // output: TRUE - val expected: TlaEx = tla.bool(true) - assert(expected == sugarFree) + val expected: TlaEx = tla.bool(true).typed() + assert(expected eqTyped output) } test("""rewrite UNCHANGED f[i] to (f[i])' = f[i]""") { // this is a regression for issue #471 // input: UNCHANGED f[i] - val app = tla.appFun(tla.name("f"), tla.name("i")) - val sugarFree = desugarer.transform(tla.unchangedTup(app)) + val app = tla + .appFun(tla.name("f") ? "f1", tla.name("i") ? "i") + .typed(unchangedTypes, "i") + val input = tla + .unchanged(app) + .typed(BoolT1()) + val output = desugarer.transform(input) // output: (f[i])' = f[i] - val expected: TlaEx = tla.eql(tla.prime(app), app) - assert(expected == sugarFree) + val expected: TlaEx = + tla + .eql(tla.prime(app) ? "b", app) + .typed(unchangedTypes, "b") + assert(expected eqTyped output) } test("simplify tuples in filters") { // TLA+ allows the user to write tuples in expanded form. We introduce tuples instead. // input: { <> >> \in XYZ: x = 3 /\ y = 4 } - val filter = - tla.filter(tla.tuple(tla.name("x"), tla.tuple(tla.name("y"), tla.name("z"))), tla.name("XYZ"), - tla.and(tla.eql(tla.name("x"), tla.int(3)), tla.eql(tla.name("y"), tla.int(4)))) - val sugarFree = desugarer.transform(filter) + val input = + tla + .filter(tla.tuple(tla.name("x") ? "i", + tla.tuple(tla.name("y") ? "i", tla.name("z") ? "i") ? "ii_2") ? "i_ii_2_2", + tla.name("XYZ") ? "I_II_2_2", + tla.and(tla.eql(tla.name("x") ? "i", tla.int(3)) ? "b", + tla.eql(tla.name("y") ? "i", tla.int(4)) ? "b") ? "b") + .typed(tupleTypes, "I_II_2_2") + val sugarFree = desugarer.transform(input) // output: { x_y_z \in XYZ: x_y_z[1] = 3 /\ x_y_z[2][1] = 4 } - val expected: TlaEx = - tla.filter(tla.name("x_y_z"), tla.name("XYZ"), - tla.and( - tla.eql(tla.appFun(tla.name("x_y_z"), tla.int(1)), tla.int(3)), - tla.eql(tla.appFun(tla.appFun(tla.name("x_y_z"), tla.int(2)), tla.int(1)), tla.int(4)) - )) //// - assert(expected == sugarFree) + val output = + tla + .filter(tla.name("x_y_z") ? "i_ii_2_2", tla.name("XYZ") ? "I_II_2_2", + tla.and( + tla.eql(tla.appFun(tla.name("x_y_z") ? "i_ii_2_2", tla.int(1)) ? "i", tla.int(3) ? "i") ? "b", + tla.eql(tla.appFun(tla.appFun(tla.name("x_y_z") ? "i_ii_2_2", tla.int(2)) ? "i", tla.int(1)) ? "i", + tla.int(4)) ? "b" + ) ? "b") + .typed(tupleTypes, "I_II_2_2") + assert(output eqTyped sugarFree) } test("simplify tuples in maps") { // TLA+ allows the user to write tuples in expanded form. We introduce tuples instead. // input: { <> >> \in XYZ |-> x + y } val map = - tla.map(tla.plus(tla.name("x"), tla.name("y")), tla.tuple(tla.name("x"), tla.tuple(tla.name("y"), tla.name("z"))), - tla.name("XYZ")) - val sugarFree = desugarer.transform(map) + tla + .map(tla.plus(tla.name("x") ? "i", tla.name("y") ? "i") ? "i", + tla.tuple(tla.name("x") ? "i", tla.tuple(tla.name("y") ? "i", tla.name("z") ? "i") ? "ii_2") ? "i_ii_2_2", + tla.name("XYZ") ? "I_II_2_2") + .typed(tupleTypes, "II_2") + val output = desugarer.transform(map) // output: { x_y_z \in XYZ: x_y_z[1] + x_y_z[2][1] } - val expected: TlaEx = - tla.map( - tla.plus(tla.appFun(tla.name("x_y_z"), tla.int(1)), - tla.appFun(tla.appFun(tla.name("x_y_z"), tla.int(2)), tla.int(1))), - tla.name("x_y_z"), - tla.name("XYZ") - ) //// - assert(expected == sugarFree) + val expected = + tla + .map( + tla.plus(tla.appFun(tla.name("x_y_z") ? "i_ii_2_2", tla.int(1)) ? "i", + tla.appFun(tla.appFun(tla.name("x_y_z") ? "i_ii_2_2", tla.int(2)) ? "ii_2", tla.int(1)) ? "i") ? "i", + tla.name("x_y_z") ? "i_ii_2_2", + tla.name("XYZ") ? "I_II_2_2" + ) + .typed(tupleTypes, "II_2") + assert(expected eqTyped output) } test("simplify tuples in existentials") { // TLA+ allows the user to write tuples in expanded form. We introduce tuples instead. // input: \E <> >> \in XYZ: x = 3 /\ y = 4 } - val filter = - tla.exists(tla.tuple(tla.name("x"), tla.tuple(tla.name("y"), tla.name("z"))), tla.name("XYZ"), - tla.and(tla.eql(tla.name("x"), tla.int(3)), tla.eql(tla.name("y"), tla.int(4)))) - val sugarFree = desugarer.transform(filter) + val input = + tla + .exists(tla.tuple(tla.name("x") ? "i", + tla.tuple(tla.name("y") ? "i", tla.name("z") ? "i") ? "ii_2") ? "i_ii_2_2", + tla.name("XYZ") ? "I_II_2_2", + tla.and(tla.eql(tla.name("x") ? "i", tla.int(3)) ? "b", + tla.eql(tla.name("y") ? "i", tla.int(4)) ? "b") ? "b") + .typed(tupleTypes, "b") + val sugarFree = desugarer.transform(input) // output: \E x_y_z \in XYZ: x_y_z[1] = 3 /\ x_y_z[2][1] = 4 } - val expected: TlaEx = - tla.exists(tla.name("x_y_z"), tla.name("XYZ"), - tla.and( - tla.eql(tla.appFun(tla.name("x_y_z"), tla.int(1)), tla.int(3)), - tla.eql(tla.appFun(tla.appFun(tla.name("x_y_z"), tla.int(2)), tla.int(1)), tla.int(4)) - )) //// - assert(expected == sugarFree) + val output = + tla + .exists(tla.name("x_y_z") ? "i_ii_2_2", tla.name("XYZ") ? "I_II_2_2", + tla.and( + tla.eql(tla.appFun(tla.name("x_y_z") ? "i_ii_2_2", tla.int(1)) ? "i", tla.int(3)) ? "b", + tla.eql(tla.appFun(tla.appFun(tla.name("x_y_z") ? "i_ii_2_2", tla.int(2)) ? "i", tla.int(1)) ? "i", + tla.int(4)) ? "b" + ) ? "b") + .typed(tupleTypes, "b") + assert(output eqTyped sugarFree) } test("simplify tuples in universals") { // TLA+ allows the user to write tuples in expanded form. We introduce tuples instead. // input: \A <> >> \in XYZ: x = 3 /\ y = 4 } - val filter = - tla.forall(tla.tuple(tla.name("x"), tla.tuple(tla.name("y"), tla.name("z"))), tla.name("XYZ"), - tla.and(tla.eql(tla.name("x"), tla.int(3)), tla.eql(tla.name("y"), tla.int(4)))) - val sugarFree = desugarer.transform(filter) + val input = + tla + .forall(tla.tuple(tla.name("x") ? "i", + tla.tuple(tla.name("y") ? "i", tla.name("z") ? "i") ? "ii_2") ? "i_ii_2_2", + tla.name("XYZ") ? "I_II_2_2", + tla.and(tla.eql(tla.name("x") ? "i", tla.int(3)) ? "b", + tla.eql(tla.name("y") ? "i", tla.int(4)) ? "b") ? "b") + .typed(tupleTypes, "b") + val sugarFree = desugarer.transform(input) // output: \A x_y_z \in XYZ: x_y_z[1] = 3 /\ x_y_z[2][1] = 4 } - val expected: TlaEx = - tla.forall(tla.name("x_y_z"), tla.name("XYZ"), - tla.and( - tla.eql(tla.appFun(tla.name("x_y_z"), tla.int(1)), tla.int(3)), - tla.eql(tla.appFun(tla.appFun(tla.name("x_y_z"), tla.int(2)), tla.int(1)), tla.int(4)) - )) //// - assert(expected == sugarFree) + val output = + tla + .forall(tla.name("x_y_z") ? "i_ii_2_2", tla.name("XYZ") ? "I_II_2_2", + tla.and( + tla.eql(tla.appFun(tla.name("x_y_z") ? "i_ii_2_2", tla.int(1)) ? "i", tla.int(3)) ? "b", + tla.eql(tla.appFun(tla.appFun(tla.name("x_y_z") ? "i_ii_2_2", tla.int(2)) ? "i", tla.int(1)) ? "i", + tla.int(4)) ? "b" + ) ? "b") + .typed(tupleTypes, "b") + assert(output eqTyped sugarFree) } test("simplify tuples in functions") { // TLA+ allows the user to write tuples in expanded form. We introduce tuples instead. // input: [<> >> \in XYZ |-> x + y] - val map = - tla.funDef(tla.plus(tla.name("x"), tla.name("y")), - tla.tuple(tla.name("x"), tla.tuple(tla.name("y"), tla.name("z"))), tla.name("XYZ")) - val sugarFree = desugarer.transform(map) - // output: [ x_y_z \in XYZ |-> x_y_z[1] + x_y_z[2][1] ] - val expected: TlaEx = - tla.funDef( - tla.plus(tla.appFun(tla.name("x_y_z"), tla.int(1)), - tla.appFun(tla.appFun(tla.name("x_y_z"), tla.int(2)), tla.int(1))), - tla.name("x_y_z"), - tla.name("XYZ") - ) //// - assert(expected == sugarFree) + val input = + tla + .funDef(tla.plus(tla.name("x") ? "i", tla.name("y") ? "i") ? "i", + tla.tuple(tla.name("x") ? "i", tla.tuple(tla.name("y") ? "i", tla.name("z") ? "i") ? "ii_2") ? "i_ii_2_2", + tla.name("XYZ") ? "I_II_2_2") + .typed(tupleTypes, "i_ii_2_2_to_ii_2") + val output = desugarer.transform(input) + // output: { x_y_z \in XYZ: x_y_z[1] + x_y_z[2][1] } + val expected = + tla + .funDef( + tla.plus(tla.appFun(tla.name("x_y_z") ? "i_ii_2_2", tla.int(1)) ? "i", + tla.appFun(tla.appFun(tla.name("x_y_z") ? "i_ii_2_2", tla.int(2)) ? "ii_2", tla.int(1)) ? "i") ? "i", + tla.name("x_y_z") ? "i_ii_2_2", + tla.name("XYZ") ? "I_II_2_2" + ) + .typed(tupleTypes, "i_ii_2_2_to_ii_2") + + assert(expected eqTyped output) } test("keep one argument functions") { // make sure that a function of a single argument does not get modified, e.g., no tuples added // input: [x \in X |-> {x}] - val fundef: TlaEx = - tla.funDef(tla.enumSet(tla.name("x")), tla.name("x"), tla.name("X")) - val sugarFree = desugarer.transform(fundef) - assert(fundef == sugarFree) + val input = + tla + .funDef(tla.enumSet(tla.name("x") ? "i") ? "I", tla.name("x") ? "i", tla.name("X") ? "I") + .typed(tupleTypes, "i_to_I") + val output = desugarer.transform(input) + assert(input eqTyped output) } test("simplify multi-argument functions") { // The user may write multi-argument functions, which we collapse into tuples // input: [ x \in X, y \in Y |-> x + y ] val map = - tla.funDef(tla.plus(tla.name("x"), tla.name("y")), tla.name("x"), tla.name("X"), tla.name("y"), tla.name("Y")) + tla + .funDef(tla.plus(tla.name("x") ? "i", tla.name("y") ? "i") ? "i", tla.name("x") ? "i", tla.name("X") ? "I", + tla.name("y") ? "i", tla.name("Y") ? "I") + .typed(tupleTypes, "ii_to_i") val sugarFree = desugarer.transform(map) // output: [ x_y \in X \X Y |-> x_y[1] + x_y[2] ] val expected: TlaEx = - tla.funDef( - tla.plus(tla.appFun(tla.name("x_y"), tla.int(1)), tla.appFun(tla.name("x_y"), tla.int(2))), - tla.name("x_y"), - tla.times(tla.name("X"), tla.name("Y")) - ) //// - assert(expected == sugarFree) + tla + .funDef( + tla.plus(tla.appFun(tla.name("x_y") ? "ii_2", tla.int(1)) ? "i", + tla.appFun(tla.name("x_y") ? "ii_2", tla.int(2)) ? "i") ? "i", + tla.name("x_y") ? "ii_2", + tla.times(tla.name("X") ? "I", tla.name("Y") ? "I") ? "II_2" + ) + .typed(tupleTypes, "ii_to_i") + assert(expected eqTyped sugarFree) } test("simplify tuples in recursive functions") { // TLA+ allows the user to write tuples in expanded form. We introduce tuples instead. // input: f[x \in S, y \in T] == x + y - val map = - tla.recFunDef(tla.plus(tla.name("x"), tla.name("y")), tla.name("x"), tla.name("S"), tla.name("y"), tla.name("T")) - val sugarFree = desugarer.transform(map) + val input = + tla + .recFunDef(tla.plus(tla.name("x") ? "i", tla.name("y") ? "i") ? "i", tla.name("x") ? "i", tla.name("S") ? "I", + tla.name("y") ? "i", tla.name("T") ? "I") + .typed(tupleTypes, "ii_to_i") + val output = desugarer.transform(input) // output: f[x_y \in S \X T] == x_y[1] + x_y[2] val expected: TlaEx = - tla.recFunDef( - tla.plus(tla.appFun(tla.name("x_y"), tla.int(1)), tla.appFun(tla.name("x_y"), tla.int(2))), - tla.name("x_y"), - tla.times(tla.name("S"), tla.name("T")) - ) //// - assert(expected == sugarFree) + tla + .recFunDef( + tla.plus(tla.appFun(tla.name("x_y") ? "ii_2", tla.int(1)) ? "i", + tla.appFun(tla.name("x_y") ? "ii_2", tla.int(2)) ? "i") ? "i", + tla.name("x_y") ? "ii_2", + tla.times(tla.name("S") ? "I", tla.name("T") ? "I") ? "II_2" + ) + .typed(tupleTypes, "ii_to_i") + assert(expected eqTyped output) } test("keep one argument recursive functions") { // make sure that a function of a single argument does not get modified, e.g., no tuples added // input: [x \in X |-> {x}] - val recFun: TlaEx = - tla.recFunDef(tla.enumSet(tla.name("x")), tla.name("x"), tla.name("X")) - val sugarFree = desugarer.transform(recFun) - assert(recFun == sugarFree) + val input: TlaEx = + tla + .recFunDef(tla.enumSet(tla.name("x") ? "i") ? "I", tla.name("x") ? "i", tla.name("X") ? "I") + .typed(tupleTypes, "i_to_I") + val output = desugarer.transform(input) + assert(input eqTyped output) } test("accept calls to user-defined operators") { + val types = Map("i" -> IntT1(), "F" -> OperT1(Seq(), IntT1())) // Foo(1) - val app: TlaEx = tla.appOp(tla.name("Foo"), tla.int(1)) - val sugarFree = desugarer(app) + val input = + tla + .appOp(tla.name("Foo") ? "F", tla.int(1) ? "i") + .typed(types, "i") + val output = desugarer(input) // do nothing and do not complain - assert(sugarFree == app) + assert(output eqTyped input) } test("accept n-ary let-in definitions") { + val types = Map("i" -> IntT1(), "F" -> OperT1(Seq(), IntT1())) // Foo(1) - val fooDef = tla.declOp("Foo", tla.name("x"), SimpleFormalParam("x")).untypedOperDecl() - val letIn: TlaEx = tla.letIn(tla.appOp(tla.name("Foo"), tla.int(1)), fooDef) - val sugarFree = desugarer(letIn) + val fooDef = tla + .declOp("Foo", tla.name("x") ? "i", SimpleFormalParam("x")) + .typedOperDecl(types, "F") + val input = tla + .letIn(tla.appOp(tla.name("Foo") ? "F", tla.int(1) ? "i") ? "i", fooDef) + .typed(types, "i") + val output = desugarer(input) // do nothing and do not complain - assert(sugarFree == letIn) + assert(output == input) } } diff --git a/tla-pp/src/test/scala/at/forsyte/apalache/tla/pp/TestExprOptimizer.scala b/tla-pp/src/test/scala/at/forsyte/apalache/tla/pp/TestExprOptimizer.scala index d32f961631..dc28c8bc1d 100644 --- a/tla-pp/src/test/scala/at/forsyte/apalache/tla/pp/TestExprOptimizer.scala +++ b/tla-pp/src/test/scala/at/forsyte/apalache/tla/pp/TestExprOptimizer.scala @@ -1,8 +1,9 @@ package at.forsyte.apalache.tla.pp +import at.forsyte.apalache.tla.lir.{BoolT1, FunT1, IntT1, OperT1, RecT1, SetT1} import at.forsyte.apalache.tla.lir.convenience._ import at.forsyte.apalache.tla.lir.transformations.impl.TrackerWithListeners -import at.forsyte.apalache.tla.typecheck.TypedPredefs._ +import at.forsyte.apalache.tla.lir.TypedPredefs._ import at.forsyte.apalache.tla.typecheck._ import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner diff --git a/tla-pp/src/test/scala/at/forsyte/apalache/tla/pp/TestInlinerofUserOper.scala b/tla-pp/src/test/scala/at/forsyte/apalache/tla/pp/TestInlinerofUserOper.scala new file mode 100644 index 0000000000..72e68ddeb6 --- /dev/null +++ b/tla-pp/src/test/scala/at/forsyte/apalache/tla/pp/TestInlinerofUserOper.scala @@ -0,0 +1,100 @@ +package at.forsyte.apalache.tla.pp + +import at.forsyte.apalache.tla.lir.convenience.tla +import at.forsyte.apalache.tla.lir.storage.BodyMapFactory +import at.forsyte.apalache.tla.lir.transformations.impl.IdleTracker +import at.forsyte.apalache.tla.lir.transformations.standard.InlinerOfUserOper +import at.forsyte.apalache.tla.lir.{BoolT1, IntT1, OperT1, SimpleFormalParam, TestingPredefs} +import at.forsyte.apalache.tla.lir.TypedPredefs._ +import org.junit.runner.RunWith +import org.scalatest.FunSuite +import org.scalatest.junit.JUnitRunner + +@RunWith(classOf[JUnitRunner]) +class TestInlinerofUserOper extends FunSuite with TestingPredefs { + + import tla._ + + test("Test Inline") { + val types = + Map("i" -> IntT1(), "b" -> BoolT1(), "U" -> OperT1(Seq(), IntT1()), "C" -> OperT1(Seq(IntT1()), IntT1())) + val cBody = plus(n_x ? "i", int(1)) + .typed(types, "i") + // C(x) == x + 1 + // B == k + // A == B() + val cDecl = declOp("C", cBody, SimpleFormalParam("x")) + .typedOperDecl(types, "C") + val aDecl = declOp("A", appOp(tla.name("B") ? "U") ? "i") + .typedOperDecl(types, "U") + val bDecl = declOp("B", name("k") ? "i") + .typedOperDecl(types, "C") + + val operDecls = Seq(aDecl, bDecl, cDecl) + + val bodies = BodyMapFactory.makeFromDecls(operDecls) + + val transformation = InlinerOfUserOper(bodies, new IdleTracker()) + + // B + val ex1 = tla + .name("B") + .typed(types, "U") + // B() + val ex2 = appOp(name("B") ? "U") + .typed(types, "i") + // A + val ex3 = tla + .name("A") + .typed(types, "U") + // A() + val ex4 = appOp(name("A") ? "U") + .typed(types, "i") + // 1 = 0 \/ C(A()) >= 0 + val ex5 = or( + eql(int(1), int(0)) ? "b", + ge(appOp(tla.name("C") ? "C", appOp(tla.name("A") ? "U") ? "i") ? "i", int(0)) ? "b" + ) + .typed(types, "b") + // LET X == C(p) IN X() + val ex6 = letIn( + appOp(tla.name("X") ? "U") ? "i", + declOp("X", appOp(tla.name("C") ? "C", tla.name("p") ? "i") ? "i") + .typedOperDecl(types, "U") + ).typed(types, "i") + + // no inlining, as B is just passed by name + val expected1 = tla.name("B").typed(types, "U") + val expected2 = tla.name("k").typed(types, "i") + val expected3 = tla.name("A").typed(types, "U") + val expected4 = tla.name("k").typed(types, "i") + // the bodies of A and C are inlined + val expected5 = or( + eql(int(1), int(0)) ? "b", + ge(plus(tla.name("k") ? "i", int(1)) ? "i", int(0)) ? "b" + ).typed(types, "b") + // C is inlined, but X is not + val expected6 = letIn( + appOp(tla.name("X") ? "U") ? "i", + declOp("X", plus(tla.name("p") ? "i", int(1)) ? "i") + .typedOperDecl(types, "U") + ).typed(types, "i") + + assert(expected1 == transformation(ex1)) + assert(expected2 == transformation(ex2)) + assert(expected3 == transformation(ex3)) + assert(expected4 == transformation(ex4)) + assert(expected5 == transformation(ex5)) + assert(expected6 == transformation(ex6)) + + val applyCwithNoArgs = appOp(tla.name("C") ? "U").typed(types, "i") + val applyCwithTwoArgs = appOp(tla.name("C") ? "U", tla.name("a") ? "i", tla.name("b") ? "i") + .typed(types, "i") + assertThrows[IllegalArgumentException] { + transformation(applyCwithNoArgs) + } + assertThrows[IllegalArgumentException] { + transformation(applyCwithTwoArgs) + } + } +} diff --git a/tla-pp/src/test/scala/at/forsyte/apalache/tla/pp/TestKeramelizer.scala b/tla-pp/src/test/scala/at/forsyte/apalache/tla/pp/TestKeramelizer.scala index 43182e0acf..be632b33fc 100644 --- a/tla-pp/src/test/scala/at/forsyte/apalache/tla/pp/TestKeramelizer.scala +++ b/tla-pp/src/test/scala/at/forsyte/apalache/tla/pp/TestKeramelizer.scala @@ -1,13 +1,12 @@ package at.forsyte.apalache.tla.pp -import at.forsyte.apalache.tla.lir.TlaEx +import at.forsyte.apalache.tla.lir.{BoolT1, FunT1, IntT1, RecT1, SetT1, StrT1, TlaEx, TupT1} import at.forsyte.apalache.tla.lir.transformations.impl.TrackerWithListeners import org.junit.runner.RunWith import org.scalatest.{BeforeAndAfterEach, FunSuite} import org.scalatest.junit.JUnitRunner import at.forsyte.apalache.tla.lir.convenience._ -import at.forsyte.apalache.tla.typecheck.{BoolT1, FunT1, IntT1, RecT1, SetT1, StrT1, TupT1} -import at.forsyte.apalache.tla.typecheck.TypedPredefs._ +import at.forsyte.apalache.tla.lir.TypedPredefs._ @RunWith(classOf[JUnitRunner]) class TestKeramelizer extends FunSuite with BeforeAndAfterEach { diff --git a/tla-pp/src/test/scala/at/forsyte/apalache/tla/pp/TestNormalizer.scala b/tla-pp/src/test/scala/at/forsyte/apalache/tla/pp/TestNormalizer.scala index bdcb4594ff..e20f09efb5 100644 --- a/tla-pp/src/test/scala/at/forsyte/apalache/tla/pp/TestNormalizer.scala +++ b/tla-pp/src/test/scala/at/forsyte/apalache/tla/pp/TestNormalizer.scala @@ -1,8 +1,9 @@ package at.forsyte.apalache.tla.pp +import at.forsyte.apalache.tla.lir.BoolT1 import at.forsyte.apalache.tla.lir.convenience._ import at.forsyte.apalache.tla.lir.transformations.impl.IdleTracker -import at.forsyte.apalache.tla.typecheck.TypedPredefs._ +import at.forsyte.apalache.tla.lir.TypedPredefs._ import at.forsyte.apalache.tla.typecheck._ import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner diff --git a/tla-pp/src/test/scala/at/forsyte/apalache/tla/pp/TestOperAppToLetInDef.scala b/tla-pp/src/test/scala/at/forsyte/apalache/tla/pp/TestOperAppToLetInDef.scala index b9800efc04..e79999a536 100644 --- a/tla-pp/src/test/scala/at/forsyte/apalache/tla/pp/TestOperAppToLetInDef.scala +++ b/tla-pp/src/test/scala/at/forsyte/apalache/tla/pp/TestOperAppToLetInDef.scala @@ -3,9 +3,8 @@ package at.forsyte.apalache.tla.pp import at.forsyte.apalache.tla.lir.convenience.tla import at.forsyte.apalache.tla.lir.oper.{TlaArithOper, TlaOper} import at.forsyte.apalache.tla.lir._ -import at.forsyte.apalache.tla.typecheck.TypedPredefs._ +import TypedPredefs._ import at.forsyte.apalache.tla.lir.transformations.impl.TrackerWithListeners -import at.forsyte.apalache.tla.typecheck.{BoolT1, IntT1, OperT1, SetT1, TupT1} import org.junit.runner.RunWith import org.scalatest.{BeforeAndAfterEach, FunSuite} import org.scalatest.junit.JUnitRunner diff --git a/tla-pp/src/test/scala/at/forsyte/apalache/tla/pp/TestParameterNormalizer.scala b/tla-pp/src/test/scala/at/forsyte/apalache/tla/pp/TestParameterNormalizer.scala index a06a7a785d..a2655c92f0 100644 --- a/tla-pp/src/test/scala/at/forsyte/apalache/tla/pp/TestParameterNormalizer.scala +++ b/tla-pp/src/test/scala/at/forsyte/apalache/tla/pp/TestParameterNormalizer.scala @@ -1,6 +1,6 @@ package at.forsyte.apalache.tla.pp -import at.forsyte.apalache.tla.lir.UntypedPredefs._ +import at.forsyte.apalache.tla.lir.TypedPredefs._ import at.forsyte.apalache.tla.lir.convenience.tla import at.forsyte.apalache.tla.lir.oper.TlaOper import at.forsyte.apalache.tla.lir.transformations.impl.TrackerWithListeners @@ -12,7 +12,7 @@ import org.scalatest.{BeforeAndAfterEach, FunSuite} @RunWith(classOf[JUnitRunner]) class TestParameterNormalizer extends FunSuite with BeforeAndAfterEach with TestingPredefs { - val noTracker = TrackerWithListeners() + private val noTracker = TrackerWithListeners() val decisionFn: TlaOperDecl => Boolean = { _ => true } var parNorm = new ParameterNormalizer(new UniqueNameGenerator, noTracker, decisionFn) @@ -23,57 +23,99 @@ class TestParameterNormalizer extends FunSuite with BeforeAndAfterEach with Test test("Nullary: No-op") { // A == 1 - val decl = tla.declOp("A", tla.int(1)).untypedOperDecl() + val input = tla + .declOp("A", tla.int(1)) + .typedOperDecl(OperT1(Seq(), IntT1())) - val pnf = parNorm.normalizeDeclaration(decl) - - assert(pnf == decl) + val output = parNorm.normalizeDeclaration(input) + assert(output == input && output.typeTag == input.typeTag) } test("Simple parameter") { - + val types = Map("i" -> IntT1(), "A" -> OperT1(Seq(IntT1()), IntT1()), "P" -> OperT1(Seq(), IntT1())) // A(p) == p - val decl = tla.declOp("A", n_p, "p").untypedOperDecl() + val input = tla + .declOp("A", tla.name("p") ? "i", "p") + .typedOperDecl(types, "A") + + // A(p) == + // LET new == p IN + // new + val output = parNorm.normalizeDeclaration(input) - val pnf = parNorm.normalizeDeclaration(decl) + output match { + case d @ TlaOperDecl(name, List(SimpleFormalParam(p)), body) => + assert(input.typeTag == d.typeTag) - val assertCond = pnf match { - case TlaOperDecl(name, SimpleFormalParam(p) :: Nil, body) => body match { - case LetInEx(letInBody, TlaOperDecl(newName, Nil, NameEx(`p`))) => - name != newName && letInBody == OperEx(TlaOper.apply, NameEx(newName)) - case _ => false + case ex1 @ LetInEx(letInBody, TlaOperDecl(newName, Nil, NameEx(paramName))) => + assert(name != newName) + assert(p == paramName) + assert(Typed(IntT1()) == ex1.typeTag) + + letInBody match { + case ex2 @ OperEx(TlaOper.apply, nested @ NameEx(nestedName)) => + assert(Typed(IntT1()) == ex2.typeTag) + assert(nestedName == newName) + assert(Typed(types("P")) == nested.typeTag) + + case _ => + fail("expected OperEx") + } + + case _ => + fail("expected LetInEx") } - case _ => false - } - - assert(assertCond) + case _ => + fail("expected TlaOperDecl") + } } test("HO parameter") { + val types = Map("i" -> IntT1(), "T" -> OperT1(Seq(IntT1()), IntT1()), + "A" -> OperT1(Seq(OperT1(Seq(IntT1()), IntT1())), IntT1())) // A(T(_)) == T(0) - val decl = tla.declOp("A", tla.appOp(n_T, tla.int(0)), ("T", 1)).untypedOperDecl() + val input = tla + .declOp("A", tla.appOp(n_T, tla.int(0)) ? "i", ("T", 1)) + .typedOperDecl(types, "A") + + val output = parNorm.normalizeDeclaration(input) - val pnf = parNorm.normalizeDeclaration(decl) + // A(T(_)) == + // LET NewName(p) == T(p) IN + // NewName(0) + output match { + case d @ TlaOperDecl(name, List(OperFormalParam(opName, 1)), body) => + assert(input.typeTag == d.typeTag) - val assertCond = pnf match { - case TlaOperDecl(name, OperFormalParam(opName, 1) :: Nil, body) => body match { - case LetInEx(letInBody, TlaOperDecl(newName, SimpleFormalParam(fakeArg) :: Nil, OperEx(TlaOper.apply, NameEx( - `opName`), NameEx(arg)))) => - arg == fakeArg && - name != newName && - letInBody == OperEx(TlaOper.apply, NameEx(newName), 0) - case _ => false + case letin @ LetInEx(letInBody, TlaOperDecl(newName, List(SimpleFormalParam(intermediateParam)), + appex @ OperEx(TlaOper.apply, nex @ NameEx(appliedOperName), NameEx(arg)))) => + assert(opName == appliedOperName) + assert(arg == intermediateParam) + assert(Typed(IntT1()) == letin.typeTag) + assert(Typed(IntT1()) == appex.typeTag) + assert(Typed(types("T")) == nex.typeTag) + + letInBody match { + case appex2 @ OperEx(TlaOper.apply, nex2 @ NameEx(innerName), arg2) => + assert(newName == innerName) + assert(tla.int(0).typed() == arg2) + assert(Typed(IntT1()) == appex2.typeTag) + assert(Typed(types("T")) == nex2.typeTag) + + case _ => fail("Expected OperEx") + } + + case _ => fail("expected LetInEx") } - case _ => false - } - - assert(assertCond) + case _ => + fail("expected TlaOperDecl") + } } } diff --git a/tla-pp/src/test/scala/at/forsyte/apalache/tla/pp/TestTlcConfigImporter.scala b/tla-pp/src/test/scala/at/forsyte/apalache/tla/pp/TestTlcConfigImporter.scala index eef05a3d2e..63d9bcdd2a 100644 --- a/tla-pp/src/test/scala/at/forsyte/apalache/tla/pp/TestTlcConfigImporter.scala +++ b/tla-pp/src/test/scala/at/forsyte/apalache/tla/pp/TestTlcConfigImporter.scala @@ -4,12 +4,13 @@ import at.forsyte.apalache.io.annotations.store._ import at.forsyte.apalache.io.tlc.TlcConfigParserApalache import at.forsyte.apalache.tla.imp.SanyImporter import at.forsyte.apalache.tla.imp.src.SourceStore +import at.forsyte.apalache.tla.lir.Untyped import at.forsyte.apalache.tla.lir.io.PrettyWriter import at.forsyte.apalache.tla.lir.transformations.impl.IdleTracker -import at.forsyte.apalache.tla.lir.UntypedPredefs._ +import at.forsyte.apalache.tla.typecheck.{MultiTypeCheckerListener, TypeCheckerTool} import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner -import org.scalatest.{BeforeAndAfterEach, FunSuite} +import org.scalatest.{Assertion, BeforeAndAfterEach, FunSuite} import java.io.{PrintWriter, StringWriter} import scala.io.Source @@ -19,19 +20,28 @@ class TestTlcConfigImporter extends FunSuite with BeforeAndAfterEach { private var sourceStore: SourceStore = _ private var annotationStore: AnnotationStore = _ private var sanyImporter: SanyImporter = _ + private var typeChecker: TypeCheckerTool = _ override def beforeEach() { sourceStore = new SourceStore() annotationStore = createAnnotationStore() sanyImporter = new SanyImporter(sourceStore, annotationStore) + typeChecker = new TypeCheckerTool(annotationStore, inferPoly = false) } - def configureAndCompare(tla: String, tlc: String, expected: String) = { + def configureAndCompare(tla: String, tlc: String, expected: String): Assertion = { val config = TlcConfigParserApalache(tlc) val (rootName, modules) = sanyImporter.loadFromSource("test", Source.fromString(tla)) + val mod = modules(rootName) - val mod2 = new TlcConfigImporter(config, new IdleTracker())(mod) + // run the type checker + val typedModule = + typeChecker + .checkAndTag(new IdleTracker(), new MultiTypeCheckerListener(), defaultTag = { _ => Untyped() }, mod) + .get + + val mod2 = new TlcConfigImporter(config, new IdleTracker())(typedModule) val stringWriter = new StringWriter() val printWriter = new PrintWriter(stringWriter) val writer = new PrettyWriter(printWriter, 80) @@ -136,6 +146,10 @@ class TestTlcConfigImporter extends FunSuite with BeforeAndAfterEach { test("CONSTANT replacements") { configureAndCompare( """---- MODULE test ---- + |B == 1 + |D == "hello" + |Init == TRUE + |Next == TRUE |================================ """.stripMargin, """ @@ -147,6 +161,14 @@ class TestTlcConfigImporter extends FunSuite with BeforeAndAfterEach { """.stripMargin, """--------------------------------- MODULE test --------------------------------- | + |B == 1 + | + |D == "hello" + | + |Init == TRUE + | + |Next == TRUE + | |OVERRIDE_A == B | |OVERRIDE_C == D @@ -163,6 +185,12 @@ class TestTlcConfigImporter extends FunSuite with BeforeAndAfterEach { test("CONSTANT assignments and replacements") { configureAndCompare( """---- MODULE test ---- + |M == 1 + |B == "foo" + |L == TRUE + |D == 3 + |Init == TRUE + |Next == TRUE |================================ """.stripMargin, """ @@ -176,6 +204,18 @@ class TestTlcConfigImporter extends FunSuite with BeforeAndAfterEach { """.stripMargin, """--------------------------------- MODULE test --------------------------------- | + |M == 1 + | + |B == "foo" + | + |L == TRUE + | + |D == 3 + | + |Init == TRUE + | + |Next == TRUE + | |OVERRIDE_N == "ModelValue_M" | |OVERRIDE_K == "ModelValue_L" @@ -300,6 +340,11 @@ class TestTlcConfigImporter extends FunSuite with BeforeAndAfterEach { test("all features") { configureAndCompare( """---- MODULE test ---- + |M == 1 + |B == "foo" + |Init == TRUE + |Next == TRUE + |Prop == TRUE |================================ """.stripMargin, """ @@ -319,6 +364,16 @@ class TestTlcConfigImporter extends FunSuite with BeforeAndAfterEach { """.stripMargin, """--------------------------------- MODULE test --------------------------------- | + |M == 1 + | + |B == "foo" + | + |Init == TRUE + | + |Next == TRUE + | + |Prop == TRUE + | |OVERRIDE_N == "ModelValue_M" | |OVERRIDE_A == B diff --git a/tla-pp/src/test/scala/at/forsyte/apalache/tla/pp/TestUnroller.scala b/tla-pp/src/test/scala/at/forsyte/apalache/tla/pp/TestUnroller.scala index b2feb2ddfa..4230d18fe2 100644 --- a/tla-pp/src/test/scala/at/forsyte/apalache/tla/pp/TestUnroller.scala +++ b/tla-pp/src/test/scala/at/forsyte/apalache/tla/pp/TestUnroller.scala @@ -1,36 +1,48 @@ package at.forsyte.apalache.tla.pp -import at.forsyte.apalache.tla.lir.convenience.tla import at.forsyte.apalache.tla.lir._ +import at.forsyte.apalache.tla.lir.convenience.tla import at.forsyte.apalache.tla.lir.oper.TlaOper -import at.forsyte.apalache.tla.lir.transformations.impl.TrackerWithListeners +import at.forsyte.apalache.tla.lir.transformations.impl.IdleTracker import at.forsyte.apalache.tla.lir.transformations.standard.IncrementalRenaming -import at.forsyte.apalache.tla.lir.values.TlaInt -import at.forsyte.apalache.tla.lir.UntypedPredefs._ +import TypedPredefs._ import org.junit.runner.RunWith -import org.scalatest.{BeforeAndAfterEach, FunSuite} import org.scalatest.junit.JUnitRunner +import org.scalatest.{BeforeAndAfterEach, FunSuite} import scala.math.BigInt @RunWith(classOf[JUnitRunner]) class TestUnroller extends FunSuite with BeforeAndAfterEach with TestingPredefs { - val noTracker = TrackerWithListeners() + private val noTracker = new IdleTracker() private var unroller = new Unroller(new UniqueNameGenerator, noTracker, new IncrementalRenaming(noTracker)) override def beforeEach(): Unit = { unroller = new Unroller(new UniqueNameGenerator, noTracker, new IncrementalRenaming(noTracker)) } - def exAsDecl(pa: (String, TlaEx)): TlaOperDecl = TlaOperDecl(pa._1, List.empty, pa._2) + def exAsDecl(pa: (String, TlaEx)): TlaOperDecl = tla + .declOp(pa._1, pa._2) + .typedOperDecl(OperT1(Seq(), IntT1())) test("No-op") { + val strToInt = OperT1(Seq(StrT1()), IntT1()) + val types = Map("b" -> BoolT1(), "i" -> IntT1(), "T" -> strToInt) + val tDecl = tla + .declOp("T", tla.name("p").typed(StrT1()), "p") + .typedOperDecl(strToInt) + val dBody = tla + .letIn(tla.appOp(n_T ? "T", tla.str("abc")) ? "i", tDecl) + .typed(types, "i") + val cBody = tla + .and(n_x ? "i", n_P ? "b") + .typed(Map("b" -> BoolT1(), "i" -> IntT1()), "b") val decls = Seq[(String, TlaEx)]( - ("A", "1"), - ("B", 0), - ("C", tla.and(n_x, n_P)), - ("D", tla.letIn(n_T, tla.declOp("T", tla.name("p"), "p").untypedOperDecl())) + ("A", tla.str("1").typed()), + ("B", tla.int(0).typed()), + ("C", cBody), + ("D", dBody) ) map exAsDecl val module = new TlaModule("M", decls) @@ -41,17 +53,24 @@ class TestUnroller extends FunSuite with BeforeAndAfterEach with TestingPredefs } test("0 step: ParamNormalForm") { + val strToInt = OperT1(Seq(StrT1()), IntT1()) + val types = Map("b" -> BoolT1(), "i" -> IntT1(), "s" -> StrT1(), "T" -> strToInt) val name = "A" + val aBody = tla + .appOp(n_A ? "T", n_p ? "s") + .typed(types, "i") // A(p) == A(p) - val recDecl = tla.declOp(name, tla.appOp(n_A, n_p), "p").untypedOperDecl() + val recDecl = tla + .declOp(name, aBody, "p") + .typedOperDecl(strToInt) recDecl.isRecursive = true val defaultVal: BigInt = 42 val decls = recDecl +: (Seq[(String, TlaEx)]( - (Unroller.UNROLL_TIMES_PREFIX + name, 0), - (Unroller.UNROLL_DEFAULT_PREFIX + name, defaultVal.intValue) + (Unroller.UNROLL_TIMES_PREFIX + name, tla.int(0).typed(IntT1())), + (Unroller.UNROLL_DEFAULT_PREFIX + name, tla.bigInt(defaultVal.intValue).typed(IntT1())) ) map exAsDecl) val module = new TlaModule("M", decls) @@ -60,72 +79,107 @@ class TestUnroller extends FunSuite with BeforeAndAfterEach with TestingPredefs val newAOpt = unrolled.operDeclarations.find(_.name == name) - val assertCond = newAOpt.exists { case d @ TlaOperDecl(_, _, body) => - !d.isRecursive && - (body match { - case LetInEx(ValEx(TlaInt(`defaultVal`)), TlaOperDecl(_, Nil, NameEx("p"))) => + newAOpt match { + case Some(d @ TlaOperDecl(_, _, body)) => + assert(!d.isRecursive) + body match { + case LetInEx(letBody, TlaOperDecl(_, Nil, declBody)) => + assert(tla.bigInt(defaultVal).typed() == letBody) + assert(tla.name("p").typed(IntT1()) == declBody) true + case _ => false - }) + } } - - assert(assertCond) } test("1 step: Nontrivial inlining") { - val name = "A" - + // prepare the inputs + val intToInt = OperT1(Seq(IntT1()), IntT1()) + val types = Map("b" -> BoolT1(), "i" -> IntT1(), "s" -> StrT1(), "T" -> intToInt) + val declarationName = "A" + + // A(p) + val aBody = tla + .appOp(tla.name(declarationName) ? "T", n_p ? "i") + .typed(types, "i") // A(p) == A(p) - val recDecl = tla.declOp(name, tla.appOp(n_A, n_p), "p").untypedOperDecl() + val recDecl = tla + .declOp(declarationName, aBody, "p") + .typedOperDecl(intToInt) recDecl.isRecursive = true val defaultVal: BigInt = 42 val decls = recDecl +: (Seq[(String, TlaEx)]( - (Unroller.UNROLL_TIMES_PREFIX + name, 1), - (Unroller.UNROLL_DEFAULT_PREFIX + name, defaultVal.intValue) + (Unroller.UNROLL_TIMES_PREFIX + declarationName, tla.int(1).typed(IntT1())), + (Unroller.UNROLL_DEFAULT_PREFIX + declarationName, tla.int(defaultVal.intValue).typed(IntT1())) ) map exAsDecl) val module = new TlaModule("M", decls) + // call the unroller that we are testing val unrolled = unroller(module) - val newAOpt = unrolled.operDeclarations.find(_.name == name) + // test the outputs + val newAOpt = unrolled.operDeclarations.find(_.name == declarationName) - val assertCond = newAOpt.exists { case d @ TlaOperDecl(_, _, body) => - !d.isRecursive && - (body match { + newAOpt match { + case Some(d @ TlaOperDecl(_, _, body)) => + assert(!d.isRecursive) + assert(Typed(IntT1()) == d.body.typeTag) + + body match { case LetInEx(paramNormalBody, TlaOperDecl(uniqueName, Nil, NameEx("p"))) => + assert(Typed(IntT1()) == paramNormalBody.typeTag) + paramNormalBody match { - case LetInEx(ValEx(TlaInt(`defaultVal`)), TlaOperDecl(_, Nil, OperEx(TlaOper.apply, NameEx( - `uniqueName`)))) => - true - case _ => false + case LetInEx(defaultBody, TlaOperDecl(_, Nil, OperEx(TlaOper.apply, NameEx(defaultName)))) => + assert(tla.bigInt(defaultVal).typed() == defaultBody) + assert(uniqueName == defaultName) + assert(Typed(IntT1()) == defaultBody.typeTag) + + case _ => + fail("Expected second LetInEx") } - case _ => false - }) - } - assert(assertCond) + case _ => + fail("Expected first LetInEx") + } + + case None => + fail("Expected Some(TlaOperDecl(...))") + } } test("Recursive LET-IN inside non-recursive operator") { + val intToInt = OperT1(Seq(IntT1()), IntT1()) + val types = Map("b" -> BoolT1(), "i" -> IntT1(), "s" -> StrT1(), "T" -> intToInt, "X" -> OperT1(Seq(), IntT1())) val letInOpName = "A" + val aBody = tla + .appOp(n_A ? "T", n_p ? "i") + .typed(types, "i") // A(p) == A(p) - val recDecl = tla.declOp(letInOpName, tla.appOp(n_A, n_p), "p").untypedOperDecl() + val recDecl = tla + .declOp(letInOpName, aBody, "p") + .typedOperDecl(intToInt) recDecl.isRecursive = true - val appEx = tla.appDecl(recDecl, tla.int(99)).untyped() + val appEx = tla + .appOp(tla.name("A") ? "T", tla.int(99) ? "i") + .typed(types, "i") // X == LET A(p) == A(p) IN A(99) - val nonRecDecl = tla.declOp("X", tla.letIn(appEx, recDecl)).untypedOperDecl() + val nonRecDecl = tla + .declOp("X", tla.letIn(appEx, recDecl).typed(types, "i")) + .typedOperDecl(types, "X") val defaultVal: BigInt = 42 val decls = nonRecDecl +: (Seq[(String, TlaEx)]( - (Unroller.UNROLL_TIMES_PREFIX + letInOpName, 1), - (Unroller.UNROLL_DEFAULT_PREFIX + letInOpName, defaultVal.intValue) + (Unroller.UNROLL_TIMES_PREFIX + letInOpName, tla.int(1).typed(IntT1())), + (Unroller.UNROLL_DEFAULT_PREFIX + letInOpName, tla.int(defaultVal.intValue).typed(IntT1())) ) map exAsDecl) val module = new TlaModule("M", decls) @@ -140,8 +194,8 @@ class TestUnroller extends FunSuite with BeforeAndAfterEach with TestingPredefs unroller = new Unroller(new UniqueNameGenerator, noTracker, new IncrementalRenaming(noTracker)) val altDecls = recDecl +: (Seq[(String, TlaEx)]( - (Unroller.UNROLL_TIMES_PREFIX + letInOpName, 1), - (Unroller.UNROLL_DEFAULT_PREFIX + letInOpName, defaultVal.intValue) + (Unroller.UNROLL_TIMES_PREFIX + letInOpName, tla.int(1).typed(IntT1())), + (Unroller.UNROLL_DEFAULT_PREFIX + letInOpName, tla.bigInt(defaultVal.intValue).typed(IntT1())) ) map exAsDecl) val altModule = new TlaModule("N", altDecls) diff --git a/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/CachingType1Parser.scala b/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/CachingType1Parser.scala index 62ccc711fd..c1d9313405 100644 --- a/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/CachingType1Parser.scala +++ b/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/CachingType1Parser.scala @@ -1,5 +1,7 @@ package at.forsyte.apalache.tla.typecheck +import at.forsyte.apalache.tla.lir.TlaType1 + import scala.collection.mutable /** diff --git a/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/DefaultTypeCheckerListener.scala b/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/DefaultTypeCheckerListener.scala index cac88955ab..d3d8810ade 100644 --- a/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/DefaultTypeCheckerListener.scala +++ b/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/DefaultTypeCheckerListener.scala @@ -1,5 +1,6 @@ package at.forsyte.apalache.tla.typecheck +import at.forsyte.apalache.tla.lir.TlaType1 import at.forsyte.apalache.tla.typecheck.etc.{EtcRef, ExactRef} /** diff --git a/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/MultiTypeCheckerListener.scala b/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/MultiTypeCheckerListener.scala index 4fa07d6da0..ff8238d420 100644 --- a/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/MultiTypeCheckerListener.scala +++ b/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/MultiTypeCheckerListener.scala @@ -1,5 +1,6 @@ package at.forsyte.apalache.tla.typecheck +import at.forsyte.apalache.tla.lir.TlaType1 import at.forsyte.apalache.tla.typecheck.etc.{EtcRef, ExactRef} /** diff --git a/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/Type1Parser.scala b/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/Type1Parser.scala index c4979c7287..84bb4eae71 100644 --- a/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/Type1Parser.scala +++ b/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/Type1Parser.scala @@ -1,17 +1,18 @@ package at.forsyte.apalache.tla.typecheck +import at.forsyte.apalache.tla.lir.TlaType1 + /** * A trait for a parser of TS1 types in the grammar of ADR-002: * *
- *   T ::= typeConst | typeVar | Bool | Int | Str | T -> T | Set(T) | Seq(T) |
- *         <<T, ..., T>> | [h_1: T, ..., h_k: T] | (T, ..., T) => T | (T)
- *   typeConst ::= <an identifier that matches [A-Z_][A-Z0-9_]*>
- *   typeVar ::= <a single letter from [a-z]>
+ * T ::= typeConst | typeVar | Bool | Int | Str | T -> T | Set(T) | Seq(T) |
+ * <<T, ..., T>> | [h_1: T, ..., h_k: T] | (T, ..., T) => T | (T)
+ * typeConst ::= <an identifier that matches [A-Z_][A-Z0-9_]*>
+ * typeVar ::= <a single letter from [a-z]>
  * 
* * @see at.forsyte.apalache.tla.typecheck.parser.DefaultType1Parser - * * @author Igor Konnov, 2020 */ trait Type1Parser { diff --git a/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/TypeChecker.scala b/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/TypeChecker.scala index 212b8cebbb..e60b58b5c7 100644 --- a/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/TypeChecker.scala +++ b/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/TypeChecker.scala @@ -1,5 +1,6 @@ package at.forsyte.apalache.tla.typecheck +import at.forsyte.apalache.tla.lir.TlaType1 import at.forsyte.apalache.tla.typecheck.etc.EtcExpr /** diff --git a/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/TypeCheckerListener.scala b/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/TypeCheckerListener.scala index 1aab3e1748..e7a0ac8720 100644 --- a/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/TypeCheckerListener.scala +++ b/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/TypeCheckerListener.scala @@ -1,5 +1,6 @@ package at.forsyte.apalache.tla.typecheck +import at.forsyte.apalache.tla.lir.TlaType1 import at.forsyte.apalache.tla.typecheck.etc.{EtcRef, ExactRef} /** diff --git a/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/TypeCheckerTool.scala b/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/TypeCheckerTool.scala index 2b727fa7ec..6c7fa50956 100644 --- a/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/TypeCheckerTool.scala +++ b/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/TypeCheckerTool.scala @@ -2,7 +2,7 @@ package at.forsyte.apalache.tla.typecheck import at.forsyte.apalache.io.annotations.store.AnnotationStore import at.forsyte.apalache.tla.lir.transformations.TransformationTracker -import at.forsyte.apalache.tla.lir.{TlaModule, TypeTag, UID} +import at.forsyte.apalache.tla.lir.{BoolT1, TlaModule, TypeTag, UID} import at.forsyte.apalache.tla.typecheck.etc._ import at.forsyte.apalache.tla.typecheck.integration.{RecordingTypeCheckerListener, TypeRewriter} diff --git a/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/TypeContext.scala b/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/TypeContext.scala index 7fe04a23a4..415f0eeb41 100644 --- a/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/TypeContext.scala +++ b/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/TypeContext.scala @@ -1,5 +1,7 @@ package at.forsyte.apalache.tla.typecheck +import at.forsyte.apalache.tla.lir.TlaType1 + import scala.collection.immutable.SortedMap /** diff --git a/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/etc/Clause.scala b/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/etc/Clause.scala index c010439471..26fb902b98 100644 --- a/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/etc/Clause.scala +++ b/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/etc/Clause.scala @@ -1,6 +1,6 @@ package at.forsyte.apalache.tla.typecheck.etc -import at.forsyte.apalache.tla.typecheck.{TlaType1, VarT1} +import at.forsyte.apalache.tla.lir.{TlaType1, VarT1} /** * A unification constraint for the unification solver. diff --git a/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/etc/ConstraintSolver.scala b/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/etc/ConstraintSolver.scala index bf73019f76..bbd1644561 100644 --- a/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/etc/ConstraintSolver.scala +++ b/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/etc/ConstraintSolver.scala @@ -1,6 +1,6 @@ package at.forsyte.apalache.tla.typecheck.etc -import at.forsyte.apalache.tla.typecheck.TlaType1 +import at.forsyte.apalache.tla.lir.TlaType1 /** * A constraint solver that collects a series of equations and solves them with the type unification algorithm. diff --git a/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/etc/EtcBuilder.scala b/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/etc/EtcBuilder.scala index fea496d194..57339c3bd8 100644 --- a/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/etc/EtcBuilder.scala +++ b/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/etc/EtcBuilder.scala @@ -1,6 +1,6 @@ package at.forsyte.apalache.tla.typecheck.etc -import at.forsyte.apalache.tla.lir.UID +import at.forsyte.apalache.tla.lir.{TlaType1, UID} import at.forsyte.apalache.tla.typecheck._ /** diff --git a/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/etc/EtcExpr.scala b/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/etc/EtcExpr.scala index ec01fa64a4..62bf6a18b4 100644 --- a/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/etc/EtcExpr.scala +++ b/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/etc/EtcExpr.scala @@ -1,6 +1,6 @@ package at.forsyte.apalache.tla.typecheck.etc -import at.forsyte.apalache.tla.typecheck.TlaType1 +import at.forsyte.apalache.tla.lir.TlaType1 /** * An expression in a simple typed lambda calculus. Here we do not care about the concrete values, diff --git a/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/etc/EtcTypeChecker.scala b/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/etc/EtcTypeChecker.scala index a2fb40182c..c60c6252e7 100644 --- a/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/etc/EtcTypeChecker.scala +++ b/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/etc/EtcTypeChecker.scala @@ -1,5 +1,7 @@ package at.forsyte.apalache.tla.typecheck.etc +import at.forsyte.apalache.tla.lir +import at.forsyte.apalache.tla.lir.{OperT1, SetT1, TlaType1, TypingException, VarT1} import at.forsyte.apalache.tla.typecheck._ import at.forsyte.apalache.tla.typecheck.etc.EtcTypeChecker.UnwindException @@ -134,7 +136,7 @@ class EtcTypeChecker(varPool: TypeVarPool, inferPolytypes: Boolean = false) exte } // operVar = (arg_1, ..., arg_k) => resVar - solver.addConstraint(EqClause(operVar, OperT1(argTypes, resVar)) + solver.addConstraint(EqClause(operVar, lir.OperT1(argTypes, resVar)) .setOnTypeFound(onFound) .setOnTypeError(onArgsMatchError)) @@ -210,7 +212,7 @@ class EtcTypeChecker(varPool: TypeVarPool, inferPolytypes: Boolean = false) exte case None => // Let the solver compute the type. If it fails, the user has to annotate the definition - OperT1(1.to(binders.length).map(_ => varPool.fresh), varPool.fresh) + lir.OperT1(1.to(binders.length).map(_ => varPool.fresh), varPool.fresh) } // translate the binders in the lambda expression, so we can quickly propagate the types of the parameters diff --git a/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/etc/Substitution.scala b/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/etc/Substitution.scala index 37d51c9ebe..d48f059238 100644 --- a/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/etc/Substitution.scala +++ b/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/etc/Substitution.scala @@ -1,5 +1,8 @@ package at.forsyte.apalache.tla.typecheck.etc +import at.forsyte.apalache.tla.lir.{ + BoolT1, ConstT1, FunT1, IntT1, OperT1, RealT1, RecT1, SeqT1, SetT1, SparseTupT1, StrT1, TlaType1, TupT1, VarT1 +} import at.forsyte.apalache.tla.typecheck._ /** diff --git a/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/etc/ToEtcExpr.scala b/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/etc/ToEtcExpr.scala index 72eaacb711..43c1fd3c07 100644 --- a/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/etc/ToEtcExpr.scala +++ b/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/etc/ToEtcExpr.scala @@ -70,8 +70,9 @@ class ToEtcExpr(annotationStore: AnnotationStore, varPool: TypeVarPool) extends findTypeFromTagOrAnnotation(decl) match { case Some(tt) => // case 1: the definition is either annotated with a java-like annotation in a comment, or tagged with TlaType1 + val fixedType = fixLazyAnnotation(decl, tt) val letAbs = mkLetAbs(decl.ID, this(decl.body), paramsAndDoms: _*) - mkTypeDecl(ExactRef(decl.ID), decl.name, tt, letAbs) + mkTypeDecl(ExactRef(decl.ID), decl.name, fixedType, letAbs) case None => // case 2: no type annotation @@ -79,6 +80,15 @@ class ToEtcExpr(annotationStore: AnnotationStore, varPool: TypeVarPool) extends } } + // We let the user to write a instead of () => a for nullary operators. This method fixes such a lazy annotation. + private def fixLazyAnnotation(decl: TlaOperDecl, tt: TlaType1): TlaType1 = { + if (decl.formalParams.isEmpty && !tt.isInstanceOf[OperT1]) { + OperT1(Seq(), tt) + } else { + tt + } + } + // parse type from its text representation private def parseType(where: String, text: String): TlaType1 = { try { @@ -825,11 +835,16 @@ class ToEtcExpr(annotationStore: AnnotationStore, varPool: TypeVarPool) extends mkExRefApp(opsig, args) //******************************************** MISC ************************************************** - case OperEx(BmcOper.withType, lhs, _) => + case OperEx(TlaOper.label, labelledEx, nameAndArgs @ _*) => + val typeVar = varPool.fresh + mkExRefApp(OperT1(nameAndArgs.map(_ => StrT1()) :+ typeVar, typeVar), nameAndArgs :+ labelledEx) + + case OperEx(BmcOper.withType, lhs, annotation) => // Met an old type annotation. Warn the user and ignore the annotation. - logger.warn("Met an old type annotation. Ignored: " + ex) - logger.warn("See: https://apalache.informal.systems/docs/apalache/typechecker-snowcat.html") - this(lhs) + logger.error("Met an old type annotation: " + annotation) + logger.error("See: https://apalache.informal.systems/docs/apalache/typechecker-snowcat.html") + val msg = s"Old Apalache type annotations (predating 0.12.0) are no longer supported" + throw new OutdatedAnnotationsError(msg, ex) //********************************************* TLC ************************************************** case OperEx(TlcOper.print, text, value) => diff --git a/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/etc/TypeUnifier.scala b/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/etc/TypeUnifier.scala index beacfc8c86..79a759f6c7 100644 --- a/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/etc/TypeUnifier.scala +++ b/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/etc/TypeUnifier.scala @@ -1,5 +1,8 @@ package at.forsyte.apalache.tla.typecheck.etc +import at.forsyte.apalache.tla.lir.{ + BoolT1, ConstT1, FunT1, IntT1, OperT1, RealT1, RecT1, SeqT1, SetT1, SparseTupT1, StrT1, TlaType1, TupT1, VarT1 +} import at.forsyte.apalache.tla.typecheck._ import at.forsyte.apalache.tla.typecheck.etc.TypeUnifier.CycleDetected diff --git a/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/etc/TypeVarPool.scala b/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/etc/TypeVarPool.scala index 802ee39450..c545a5c029 100644 --- a/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/etc/TypeVarPool.scala +++ b/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/etc/TypeVarPool.scala @@ -1,6 +1,6 @@ package at.forsyte.apalache.tla.typecheck.etc -import at.forsyte.apalache.tla.typecheck.VarT1 +import at.forsyte.apalache.tla.lir.VarT1 class TypeVarPool(start: Int = 0) { diff --git a/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/exceptions.scala b/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/exceptions.scala index 2d2c7a2fe2..5cc1bb8226 100644 --- a/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/exceptions.scala +++ b/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/exceptions.scala @@ -6,12 +6,7 @@ package at.forsyte.apalache.tla.typecheck -/** - * This exception is thrown, whenever the type checker finds an irrecoverable error. - * - * @author konnov - */ -class TypingException(message: String) extends Exception(message) +import at.forsyte.apalache.tla.lir.TypingException /** * This exception is thrown, whenever the type checker finds an irrecoverable error in the user input. diff --git a/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/integration/RecordingTypeCheckerListener.scala b/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/integration/RecordingTypeCheckerListener.scala index f33aa7912f..5fda015111 100644 --- a/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/integration/RecordingTypeCheckerListener.scala +++ b/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/integration/RecordingTypeCheckerListener.scala @@ -1,8 +1,8 @@ package at.forsyte.apalache.tla.typecheck.integration -import at.forsyte.apalache.tla.lir.UID +import at.forsyte.apalache.tla.lir.{TlaType1, UID} import at.forsyte.apalache.tla.typecheck.etc.{EtcRef, ExactRef} -import at.forsyte.apalache.tla.typecheck.{TlaType1, TypeCheckerListener} +import at.forsyte.apalache.tla.typecheck.TypeCheckerListener import scala.collection.mutable diff --git a/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/integration/TypeRewriter.scala b/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/integration/TypeRewriter.scala index b22f9af26f..97006deec8 100644 --- a/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/integration/TypeRewriter.scala +++ b/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/integration/TypeRewriter.scala @@ -2,9 +2,8 @@ package at.forsyte.apalache.tla.typecheck.integration import at.forsyte.apalache.tla.lir.transformations.TransformationTracker import at.forsyte.apalache.tla.lir._ -import at.forsyte.apalache.tla.lir.oper.TlaFunOper +import at.forsyte.apalache.tla.lir.oper.{BmcOper, TlaFunOper} import at.forsyte.apalache.tla.lir.values.TlaStr -import at.forsyte.apalache.tla.typecheck.{StrT1, TlaType1, TupT1, TypingException} /** * This class uses the map of types to set the types of TLA+ expressions and declarations. @@ -61,6 +60,10 @@ class TypeRewriter(tracker: TransformationTracker, defaultTag: UID => TypeTag)(t OperEx(TlaFunOper.except, taggedFun +: accessorsWithTaggedValues: _*)(getOrDefault(ex.ID)) + case ex @ OperEx(BmcOper.withType, lhs, annotation) => + // an old type annotation: transform the left-hand side, as the type checker does not understand the old type annotations + OperEx(BmcOper.withType, transform(lhs), annotation)(getOrDefault(lhs.ID)) + case ex @ OperEx(oper, args @ _*) => val newArgs = args.map(this(_)) OperEx(oper, newArgs: _*)(getOrDefault(ex.ID)) diff --git a/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/integration/TypeWatchdogTransformationListener.scala b/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/integration/TypeWatchdogTransformationListener.scala new file mode 100644 index 0000000000..e733bc1771 --- /dev/null +++ b/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/integration/TypeWatchdogTransformationListener.scala @@ -0,0 +1,31 @@ +package at.forsyte.apalache.tla.typecheck.integration + +import at.forsyte.apalache.tla.lir.transformations.TransformationListener +import at.forsyte.apalache.tla.lir.{TlaDecl, TlaEx, TlaType1, Typed, TypingException, Untyped} + +/** + * A transformation tracker that throws an exception, if a typed expression has been transformed into an untyped one. + * + * @author Igor Konnov + */ +class TypeWatchdogTransformationListener extends TransformationListener { + override def onTransformation(originalEx: TlaEx, newEx: TlaEx): Unit = { + (originalEx.typeTag, newEx.typeTag) match { + case (Typed(_: TlaType1), Untyped()) => + throw new TypingException( + s"A typed expression ${originalEx.ID} was transformed to an untyped expression ${newEx.ID}") + + case _ => () + } + } + + override def onDeclTransformation(originalDecl: TlaDecl, newDecl: TlaDecl): Unit = { + (originalDecl.typeTag, newDecl.typeTag) match { + case (Typed(_: TlaType1), Untyped()) => + throw new TypingException( + s"A typed declaration ${originalDecl.name} was transformed to an untyped expression ${newDecl.name}") + + case _ => () + } + } +} diff --git a/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/parser/DefaultType1Parser.scala b/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/parser/DefaultType1Parser.scala index 8cf1d5a9fc..a732bfc08e 100644 --- a/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/parser/DefaultType1Parser.scala +++ b/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/parser/DefaultType1Parser.scala @@ -1,7 +1,10 @@ package at.forsyte.apalache.tla.typecheck.parser -import java.io.{Reader, StringReader} +import at.forsyte.apalache.tla.lir.{ + BoolT1, ConstT1, FunT1, IntT1, OperT1, RealT1, RecT1, SeqT1, SetT1, SparseTupT1, StrT1, TlaType1, TupT1, VarT1 +} +import java.io.{Reader, StringReader} import at.forsyte.apalache.tla.typecheck._ import scala.collection.immutable.SortedMap diff --git a/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/passes/EtcTypeCheckerAdapter.scala b/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/passes/EtcTypeCheckerAdapter.scala index 6609735f39..4a663c6cfd 100644 --- a/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/passes/EtcTypeCheckerAdapter.scala +++ b/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/passes/EtcTypeCheckerAdapter.scala @@ -2,8 +2,12 @@ package at.forsyte.apalache.tla.typecheck.passes import at.forsyte.apalache.infra.{ErrorMessage, ExceptionAdapter, FailureMessage, NormalErrorMessage} import at.forsyte.apalache.tla.imp.SanyException -import at.forsyte.apalache.tla.typecheck.{TypingException, TypingInputException} +import at.forsyte.apalache.tla.imp.src.SourceStore +import at.forsyte.apalache.tla.lir.storage.{ChangeListener, SourceLocator} +import at.forsyte.apalache.tla.lir.{OutdatedAnnotationsError, TypingException, UID} +import at.forsyte.apalache.tla.typecheck.TypingInputException import com.google.inject.{Inject, Singleton} +import com.typesafe.scalalogging.LazyLogging /** * The adapter for the exceptions that are produced by the parser and type checker. @@ -11,7 +15,8 @@ import com.google.inject.{Inject, Singleton} * @author Igor Konnov */ @Singleton -class EtcTypeCheckerAdapter @Inject() () extends ExceptionAdapter { +class EtcTypeCheckerAdapter @Inject() (sourceStore: SourceStore, changeListener: ChangeListener) + extends ExceptionAdapter with LazyLogging { override def toMessage: PartialFunction[Exception, ErrorMessage] = { case err: SanyException => NormalErrorMessage("Error by TLA+ parser: " + err.getMessage) @@ -19,7 +24,20 @@ class EtcTypeCheckerAdapter @Inject() () extends ExceptionAdapter { case err: TypingInputException => NormalErrorMessage("Typing input error: " + err.getMessage) + case err: OutdatedAnnotationsError => + val msg = "%s: rewriter error: %s".format(findLoc(err.causeExpr.ID), err.getMessage) + NormalErrorMessage(msg) + case err: TypingException => FailureMessage("Type checker error: " + err.getMessage) } + + private def findLoc(id: UID): String = { + val sourceLocator: SourceLocator = SourceLocator(sourceStore.makeSourceMap, changeListener) + + sourceLocator.sourceOf(id) match { + case Some(loc) => loc.toString + case None => "" + } + } } diff --git a/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/passes/EtcTypeCheckerPassImpl.scala b/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/passes/EtcTypeCheckerPassImpl.scala index 3d0f589274..0fd0a9f11f 100644 --- a/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/passes/EtcTypeCheckerPassImpl.scala +++ b/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/passes/EtcTypeCheckerPassImpl.scala @@ -23,6 +23,8 @@ class EtcTypeCheckerPassImpl @Inject() (val options: PassOptions, val sourceStor private var outputTlaModule: Option[TlaModule] = None + protected def inferPoly: Boolean = options.getOrElse("typecheck", "inferPoly", true) + /** * The name of the pass * @@ -47,7 +49,6 @@ class EtcTypeCheckerPassImpl @Inject() (val options: PassOptions, val sourceStor logger.info(" > Running Snowcat .::.") dumpToJson(tlaModule.get, "pre") - val inferPoly = options.getOrElse("typecheck", "inferPoly", true) val tool = new TypeCheckerTool(annotationStore, inferPoly) // when this flag is true by the end of type checking, we have recovered the types of all expressions diff --git a/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/passes/LoggingTypeCheckerListener.scala b/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/passes/LoggingTypeCheckerListener.scala index 328169ae67..35cbbcb4a5 100644 --- a/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/passes/LoggingTypeCheckerListener.scala +++ b/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/passes/LoggingTypeCheckerListener.scala @@ -1,9 +1,9 @@ package at.forsyte.apalache.tla.typecheck.passes import at.forsyte.apalache.tla.imp.src.SourceStore -import at.forsyte.apalache.tla.lir.UID +import at.forsyte.apalache.tla.lir.{TlaType1, UID} import at.forsyte.apalache.tla.lir.storage.{ChangeListener, SourceLocator} -import at.forsyte.apalache.tla.typecheck.{TlaType1, TypeCheckerListener} +import at.forsyte.apalache.tla.typecheck.TypeCheckerListener import at.forsyte.apalache.tla.typecheck.etc.{EtcRef, ExactRef} import com.typesafe.scalalogging.LazyLogging diff --git a/tla-types/src/main/scala/at/forsyte/apalache/tla/types/TlaType1TagPrinter.scala b/tla-types/src/main/scala/at/forsyte/apalache/tla/types/TlaType1TagPrinter.scala index 1b45137b41..5417039963 100644 --- a/tla-types/src/main/scala/at/forsyte/apalache/tla/types/TlaType1TagPrinter.scala +++ b/tla-types/src/main/scala/at/forsyte/apalache/tla/types/TlaType1TagPrinter.scala @@ -1,8 +1,7 @@ package at.forsyte.apalache.tla.types -import at.forsyte.apalache.tla.lir.{TypeTag, Typed, Untyped} +import at.forsyte.apalache.tla.lir.{TlaType1, TypeTag, Typed, Untyped} import at.forsyte.apalache.tla.lir.io.TypeTagPrinter -import at.forsyte.apalache.tla.typecheck.TlaType1 class TlaType1TagPrinter extends TypeTagPrinter { def apply(tag: TypeTag): String = { diff --git a/tla-types/src/test/scala/at/forsyte/apalache/tla/typecheck/TestCachingType1Parser.scala b/tla-types/src/test/scala/at/forsyte/apalache/tla/typecheck/TestCachingType1Parser.scala index 2d53df79ca..cf5b0b6e2d 100644 --- a/tla-types/src/test/scala/at/forsyte/apalache/tla/typecheck/TestCachingType1Parser.scala +++ b/tla-types/src/test/scala/at/forsyte/apalache/tla/typecheck/TestCachingType1Parser.scala @@ -1,5 +1,6 @@ package at.forsyte.apalache.tla.typecheck +import at.forsyte.apalache.tla.lir.IntT1 import org.junit.runner.RunWith import org.scalatest.FunSuite import org.scalatest.junit.JUnitRunner diff --git a/tla-types/src/test/scala/at/forsyte/apalache/tla/typecheck/TestDefaultType1Parser.scala b/tla-types/src/test/scala/at/forsyte/apalache/tla/typecheck/TestDefaultType1Parser.scala index c190d09df0..d8c3573da9 100644 --- a/tla-types/src/test/scala/at/forsyte/apalache/tla/typecheck/TestDefaultType1Parser.scala +++ b/tla-types/src/test/scala/at/forsyte/apalache/tla/typecheck/TestDefaultType1Parser.scala @@ -1,5 +1,8 @@ package at.forsyte.apalache.tla.typecheck +import at.forsyte.apalache.tla.lir.{ + BoolT1, ConstT1, FunT1, IntT1, OperT1, RealT1, RecT1, SeqT1, SetT1, SparseTupT1, StrT1, TupT1, VarT1 +} import at.forsyte.apalache.tla.typecheck.parser.{DefaultType1Parser, Type1ParseError} import org.junit.runner.RunWith import org.scalacheck.Gen.alphaStr diff --git a/tla-types/src/test/scala/at/forsyte/apalache/tla/typecheck/TestTlaType1.scala b/tla-types/src/test/scala/at/forsyte/apalache/tla/typecheck/TestTlaType1.scala index ec5f12753d..56e4d5a70a 100644 --- a/tla-types/src/test/scala/at/forsyte/apalache/tla/typecheck/TestTlaType1.scala +++ b/tla-types/src/test/scala/at/forsyte/apalache/tla/typecheck/TestTlaType1.scala @@ -1,5 +1,8 @@ package at.forsyte.apalache.tla.typecheck +import at.forsyte.apalache.tla.lir.{ + BoolT1, ConstT1, FunT1, IntT1, OperT1, RealT1, RecT1, SeqT1, SetT1, StrT1, TupT1, VarT1 +} import org.junit.runner.RunWith import org.scalatest.FunSuite import org.scalatest.junit.JUnitRunner diff --git a/tla-types/src/test/scala/at/forsyte/apalache/tla/typecheck/TlaType1Gen.scala b/tla-types/src/test/scala/at/forsyte/apalache/tla/typecheck/TlaType1Gen.scala index 620b823453..754c7d66fd 100644 --- a/tla-types/src/test/scala/at/forsyte/apalache/tla/typecheck/TlaType1Gen.scala +++ b/tla-types/src/test/scala/at/forsyte/apalache/tla/typecheck/TlaType1Gen.scala @@ -1,5 +1,8 @@ package at.forsyte.apalache.tla.typecheck +import at.forsyte.apalache.tla.lir.{ + BoolT1, ConstT1, FunT1, IntT1, OperT1, RealT1, RecT1, SeqT1, SetT1, StrT1, TlaType1, TupT1, VarT1 +} import org.scalacheck.Arbitrary.arbitrary import org.scalacheck.Gen import org.scalacheck.Gen.{choose, const, identifier, listOfN, lzy, oneOf, posNum, resize, sized} diff --git a/tla-types/src/test/scala/at/forsyte/apalache/tla/typecheck/etc/TestConstraintSolver.scala b/tla-types/src/test/scala/at/forsyte/apalache/tla/typecheck/etc/TestConstraintSolver.scala index d79ecb2f6c..d413841928 100644 --- a/tla-types/src/test/scala/at/forsyte/apalache/tla/typecheck/etc/TestConstraintSolver.scala +++ b/tla-types/src/test/scala/at/forsyte/apalache/tla/typecheck/etc/TestConstraintSolver.scala @@ -1,5 +1,6 @@ package at.forsyte.apalache.tla.typecheck.etc +import at.forsyte.apalache.tla.lir.{OperT1, VarT1} import at.forsyte.apalache.tla.typecheck._ import at.forsyte.apalache.tla.typecheck.parser.DefaultType1Parser import org.junit.runner.RunWith diff --git a/tla-types/src/test/scala/at/forsyte/apalache/tla/typecheck/etc/TestEtcTypeChecker.scala b/tla-types/src/test/scala/at/forsyte/apalache/tla/typecheck/etc/TestEtcTypeChecker.scala index e22f9db5f4..4c19b9e53a 100644 --- a/tla-types/src/test/scala/at/forsyte/apalache/tla/typecheck/etc/TestEtcTypeChecker.scala +++ b/tla-types/src/test/scala/at/forsyte/apalache/tla/typecheck/etc/TestEtcTypeChecker.scala @@ -1,5 +1,6 @@ package at.forsyte.apalache.tla.typecheck.etc +import at.forsyte.apalache.tla.lir.{BoolT1, IntT1, OperT1, SeqT1, SetT1, TlaType1, TupT1, VarT1} import at.forsyte.apalache.tla.typecheck._ import at.forsyte.apalache.tla.typecheck.parser.DefaultType1Parser import org.easymock.EasyMock diff --git a/tla-types/src/test/scala/at/forsyte/apalache/tla/typecheck/etc/TestToEtcExpr.scala b/tla-types/src/test/scala/at/forsyte/apalache/tla/typecheck/etc/TestToEtcExpr.scala index 3149c4961e..d69927d9d7 100644 --- a/tla-types/src/test/scala/at/forsyte/apalache/tla/typecheck/etc/TestToEtcExpr.scala +++ b/tla-types/src/test/scala/at/forsyte/apalache/tla/typecheck/etc/TestToEtcExpr.scala @@ -700,6 +700,14 @@ class TestToEtcExpr extends FunSuite with BeforeAndAfterEach with EtcBuilder { assert(expected == gen(ex)) } + test("Labels") { + val typ = parser("(Str, Str, Str, a) => a") + val expected = + mkUniqApp(Seq(typ), mkUniqConst(StrT1()), mkUniqConst(StrT1()), mkUniqConst(StrT1()), mkUniqName("x")) + val ex = tla.label(tla.name("x"), "lab", "a", "b") + assert(expected == gen(ex)) + } + test("Apalache!FunAsSeq(fun, len)") { val typ = parser("(Int -> a, Int) => Seq(a)") val expected = mkAppByName(Seq(typ), "fun", "len") @@ -766,7 +774,7 @@ class TestToEtcExpr extends FunSuite with BeforeAndAfterEach with EtcBuilder { test("old annotations: e <: tp") { val oldTypeAnnotation = tla.enumSet(tla.intSet()) val input = tla.withType(tla.name("e"), oldTypeAnnotation) - assert(mkUniqName("e") == gen(input)) + assertThrows[OutdatedAnnotationsError](gen(input)) } test("TLC!Print") { diff --git a/tla-types/src/test/scala/at/forsyte/apalache/tla/typecheck/etc/TestTypeCheckerTool.scala b/tla-types/src/test/scala/at/forsyte/apalache/tla/typecheck/etc/TestTypeCheckerTool.scala index a80c3abaed..fabc82f021 100644 --- a/tla-types/src/test/scala/at/forsyte/apalache/tla/typecheck/etc/TestTypeCheckerTool.scala +++ b/tla-types/src/test/scala/at/forsyte/apalache/tla/typecheck/etc/TestTypeCheckerTool.scala @@ -2,9 +2,9 @@ package at.forsyte.apalache.tla.typecheck.etc import at.forsyte.apalache.tla.imp.SanyImporter import at.forsyte.apalache.tla.imp.src.SourceStore -import at.forsyte.apalache.tla.typecheck.{TlaType1, Type1Parser, TypeCheckerListener, TypeCheckerTool, TypingException} +import at.forsyte.apalache.tla.typecheck.{Type1Parser, TypeCheckerListener, TypeCheckerTool} import at.forsyte.apalache.io.annotations.store._ -import at.forsyte.apalache.tla.lir.{Typed, UID} +import at.forsyte.apalache.tla.lir.{TlaType1, Typed, TypingException, UID} import at.forsyte.apalache.tla.lir.transformations.impl.IdleTracker import at.forsyte.apalache.tla.typecheck.parser.DefaultType1Parser import org.easymock.EasyMock diff --git a/tla-types/src/test/scala/at/forsyte/apalache/tla/typecheck/etc/TestTypeUnifier.scala b/tla-types/src/test/scala/at/forsyte/apalache/tla/typecheck/etc/TestTypeUnifier.scala index 2fa33b7cf0..2238891fdb 100644 --- a/tla-types/src/test/scala/at/forsyte/apalache/tla/typecheck/etc/TestTypeUnifier.scala +++ b/tla-types/src/test/scala/at/forsyte/apalache/tla/typecheck/etc/TestTypeUnifier.scala @@ -1,5 +1,6 @@ package at.forsyte.apalache.tla.typecheck.etc +import at.forsyte.apalache.tla.lir.{BoolT1, ConstT1, FunT1, IntT1, OperT1, RealT1, SetT1, StrT1, VarT1} import at.forsyte.apalache.tla.typecheck.parser.DefaultType1Parser import at.forsyte.apalache.tla.typecheck._ import org.junit.runner.RunWith diff --git a/tlair/src/main/scala/at/forsyte/apalache/tla/lir/Builder.scala b/tlair/src/main/scala/at/forsyte/apalache/tla/lir/Builder.scala index 8eeca79dcb..c40ad4e5a0 100644 --- a/tlair/src/main/scala/at/forsyte/apalache/tla/lir/Builder.scala +++ b/tlair/src/main/scala/at/forsyte/apalache/tla/lir/Builder.scala @@ -140,7 +140,7 @@ class Builder { * * @return the value expression that corresponds to Nat. */ - def natSet(): BuilderVal = BuilderVal(TlaIntSet) + def natSet(): BuilderVal = BuilderVal(TlaNatSet) /** Declarations */ @@ -565,8 +565,32 @@ class Builder { in(prime(nameToPrime), enumSet(onlySetElem)) } - // d_1 :> e_1 @@ ... @@ d_k :> e_k - def atat(args: BuilderEx*): BuilderEx = { + /** + * The TLC operator that creates a singleton function: key :> value. + */ + def colonGreater(key: BuilderEx, value: BuilderEx): BuilderEx = { + BuilderOper(TlcOper.colonGreater, key, value) + } + + /** + * The TLC operator that concatenates two functions: fun1 @@ fun2. + * + * @param lhs function on the left-hand side + * @param rhs function on the right-hand side + * @return a new function that operates on the joint domain of lhs and rhs + */ + def atat(lhs: BuilderEx, rhs: BuilderEx): BuilderEx = { + BuilderOper(TlcOper.atat, lhs, rhs) + } + + /** + * Produce a function out of a sequence of keys and values, that is, key_1 :> value_1 @@ ... @@ key_k :> value_k. + * + * TODO: this method introduces an intermediate builder expression, so it cannot be used to construct a typed expression. + * + * @param args an alternating list of keys and values + */ + def atatInterleaved(args: BuilderEx*): BuilderEx = { if (args.isEmpty) { BuilderOper(TlcOper.atat) } else { @@ -577,6 +601,7 @@ class Builder { } // apalache operators + @deprecated("This operator introduces an old-style apalache annotation. It will be removed soon.") def withType(expr: BuilderEx, typeAnnot: BuilderEx): BuilderEx = { BuilderOper(BmcOper.withType, expr, typeAnnot) } diff --git a/tlair/src/main/scala/at/forsyte/apalache/tla/lir/TestingPredefs.scala b/tlair/src/main/scala/at/forsyte/apalache/tla/lir/TestingPredefs.scala index 19a453c252..34ff4d208e 100644 --- a/tlair/src/main/scala/at/forsyte/apalache/tla/lir/TestingPredefs.scala +++ b/tlair/src/main/scala/at/forsyte/apalache/tla/lir/TestingPredefs.scala @@ -8,8 +8,6 @@ import at.forsyte.apalache.tla.lir.UntypedPredefs._ trait TestingPredefs { implicit def name(p_s: String): NameEx = NameEx(p_s) - implicit def value(p_n: Int): ValEx = ValEx(TlaInt(p_n)) - implicit def sfp(p_s: String): SimpleFormalParam = SimpleFormalParam(p_s) implicit def ofp(p_pair: (String, Int)): OperFormalParam = diff --git a/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/TlaType1.scala b/tlair/src/main/scala/at/forsyte/apalache/tla/lir/TlaType1.scala similarity index 96% rename from tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/TlaType1.scala rename to tlair/src/main/scala/at/forsyte/apalache/tla/lir/TlaType1.scala index f7ce9363da..51b3e10014 100644 --- a/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/TlaType1.scala +++ b/tlair/src/main/scala/at/forsyte/apalache/tla/lir/TlaType1.scala @@ -1,4 +1,4 @@ -package at.forsyte.apalache.tla.typecheck +package at.forsyte.apalache.tla.lir import scala.collection.immutable.SortedMap @@ -16,6 +16,15 @@ sealed trait TlaType1 { def usedNames: Set[Int] } +object TlaType1 { + def fromTypeTag(typeTag: TypeTag): TlaType1 = { + typeTag match { + case Typed(tt: TlaType1) => tt + case _ => throw new TypingException("Expected Typed(_: TlaType1), found: " + typeTag) + } + } +} + /** * An integer type. */ diff --git a/tlair/src/main/scala/at/forsyte/apalache/tla/lir/TypeTag.scala b/tlair/src/main/scala/at/forsyte/apalache/tla/lir/TypeTag.scala index 18af933764..1371f1d1b3 100644 --- a/tlair/src/main/scala/at/forsyte/apalache/tla/lir/TypeTag.scala +++ b/tlair/src/main/scala/at/forsyte/apalache/tla/lir/TypeTag.scala @@ -37,4 +37,14 @@ trait TypeTagged[T] { * @return a shallow copy of TLA+ expression with the type tag set to newTypeTag */ def withTag(newTypeTag: TypeTag): T + + /** + * Object equality combined with type tag equality. + * + * @param other another object to compare with + * @return true, if `this == other && this.typeTag == other.typeTag` + */ + def eqTyped(other: TypeTagged[T]): Boolean = { + this == other && typeTag == other.typeTag + } } diff --git a/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/TypedPredefs.scala b/tlair/src/main/scala/at/forsyte/apalache/tla/lir/TypedPredefs.scala similarity index 93% rename from tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/TypedPredefs.scala rename to tlair/src/main/scala/at/forsyte/apalache/tla/lir/TypedPredefs.scala index 7bfdd0443f..707d510669 100644 --- a/tla-types/src/main/scala/at/forsyte/apalache/tla/typecheck/TypedPredefs.scala +++ b/tlair/src/main/scala/at/forsyte/apalache/tla/lir/TypedPredefs.scala @@ -1,7 +1,6 @@ -package at.forsyte.apalache.tla.typecheck +package at.forsyte.apalache.tla.lir import at.forsyte.apalache.tla.lir.values._ -import at.forsyte.apalache.tla.lir._ object TypedPredefs { type Tag = Typed[TlaType1] @@ -29,6 +28,8 @@ object TypedPredefs { /** * An implicit wrapper around TypeTag that extract the type out of Typed(_: TlaType1). * + * TODO: shall we remove this implicit conversion in favor of TlaType1.fromTypeTag? + * * @param tag a type tag */ implicit class TypeTagAsTlaType1(tag: TypeTag) { @@ -70,6 +71,10 @@ object TypedPredefs { def typedOperDecl(topType: TlaType1): TlaOperDecl = { BuilderDeclAsTyped(block).typed(topType).asInstanceOf[TlaOperDecl] } + + def typedOperDecl(types: Map[String, TlaType1], alias: String): TlaOperDecl = { + BuilderDeclAsTyped(block).typed(types, alias).asInstanceOf[TlaOperDecl] + } } implicit class BuilderExAsTyped(block: BuilderEx) { diff --git a/tlair/src/main/scala/at/forsyte/apalache/tla/lir/exceptions.scala b/tlair/src/main/scala/at/forsyte/apalache/tla/lir/exceptions.scala index c2d23b3c09..d2ba027895 100644 --- a/tlair/src/main/scala/at/forsyte/apalache/tla/lir/exceptions.scala +++ b/tlair/src/main/scala/at/forsyte/apalache/tla/lir/exceptions.scala @@ -38,3 +38,18 @@ class BuilderError(message: String) extends LirError(message) * @param message the error message */ class CyclicDependencyError(message: String) extends LirError(message) + +/** + * This exception is thrown, whenever the code finds an irrecoverable error in expression types. + * + * @author konnov + */ +class TypingException(message: String) extends Exception(message) + +/** + * This exception is thrown when an outdated type annotation (pre 0.12.0) is met. + * + * @param message the error message + * @param causeExpr the expression that caused the error + */ +class OutdatedAnnotationsError(message: String, val causeExpr: TlaEx) extends LirError(message) diff --git a/tlair/src/main/scala/at/forsyte/apalache/tla/lir/transformations/standard/InlinerOfUserOper.scala b/tlair/src/main/scala/at/forsyte/apalache/tla/lir/transformations/standard/InlinerOfUserOper.scala index 2d686a86ee..714d6db9d1 100644 --- a/tlair/src/main/scala/at/forsyte/apalache/tla/lir/transformations/standard/InlinerOfUserOper.scala +++ b/tlair/src/main/scala/at/forsyte/apalache/tla/lir/transformations/standard/InlinerOfUserOper.scala @@ -5,7 +5,6 @@ import at.forsyte.apalache.tla.lir.oper.TlaOper import at.forsyte.apalache.tla.lir.storage.BodyMap import at.forsyte.apalache.tla.lir.transformations.standard.InlinerOfUserOper.kStepParameters import at.forsyte.apalache.tla.lir.transformations.{TlaExTransformation, TransformationTracker} -import at.forsyte.apalache.tla.lir.UntypedPredefs._ /** *

Attempts to instantiate the body of the operator named `name` with the provided arguments. @@ -75,7 +74,7 @@ class InlinerOfUserOper(defBodyMap: BodyMap, tracker: TransformationTracker) ext val bodyCopy = postTr(DeepCopy(tracker).deepCopyEx(decl.body)) val newBody = decl.formalParams.zip(args).foldLeft(bodyCopy) { case (b, (fParam, arg)) => - ReplaceFixed(tracker)(NameEx(fParam.name), arg)(b) + ReplaceFixed(tracker)(NameEx(fParam.name)(arg.typeTag), arg)(b) } // the step limit, if it was defined, decreases by 1 diff --git a/tlair/src/main/scala/at/forsyte/apalache/tla/lir/transformations/standard/LetInExpander.scala b/tlair/src/main/scala/at/forsyte/apalache/tla/lir/transformations/standard/LetInExpander.scala index 444089771d..427c86eeb6 100644 --- a/tlair/src/main/scala/at/forsyte/apalache/tla/lir/transformations/standard/LetInExpander.scala +++ b/tlair/src/main/scala/at/forsyte/apalache/tla/lir/transformations/standard/LetInExpander.scala @@ -4,7 +4,6 @@ import at.forsyte.apalache.tla.lir.transformations.{TlaExTransformation, Transfo import at.forsyte.apalache.tla.lir._ import at.forsyte.apalache.tla.lir.oper.TlaOper import at.forsyte.apalache.tla.lir.storage.BodyMapFactory -import at.forsyte.apalache.tla.lir.UntypedPredefs._ /** *

A transformation which replaces all occurrences of LET-IN expressions with @@ -58,7 +57,7 @@ class LetInExpander(tracker: TransformationTracker, keepNullary: Boolean) extend params.zip(args).foldLeft(lambdaBody) { // replace every parameter with the respective argument case (expr, (param, arg)) => - ReplaceFixed(tracker)(NameEx(param.name), arg)(expr) + ReplaceFixed(tracker)(NameEx(param.name)(arg.typeTag), arg)(expr) } // recursive processing of composite operators diff --git a/tlair/src/test/scala/at/forsyte/apalache/tla/lir/io/TestPrettyWriter.scala b/tlair/src/test/scala/at/forsyte/apalache/tla/lir/io/TestPrettyWriter.scala index 79e36e46b6..99c41c3a20 100644 --- a/tlair/src/test/scala/at/forsyte/apalache/tla/lir/io/TestPrettyWriter.scala +++ b/tlair/src/test/scala/at/forsyte/apalache/tla/lir/io/TestPrettyWriter.scala @@ -426,7 +426,7 @@ class TestPrettyWriter extends FunSuite with BeforeAndAfterEach { test("TLC @@") { val writer = new PrettyWriter(printWriter, 40) - val expr = atat(str("a"), int(1), str("b"), int(2), str("c"), int(3)) + val expr = atatInterleaved(str("a"), int(1), str("b"), int(2), str("c"), int(3)) writer.write(expr) printWriter.flush() val expected = """"a" :> 1 @@ "b" :> 2 @@ "c" :> 3""".stripMargin diff --git a/tlair/src/test/scala/at/forsyte/apalache/tla/lir/transformations/standard/TestInlinerofUserOper.scala b/tlair/src/test/scala/at/forsyte/apalache/tla/lir/transformations/standard/TestInlinerofUserOper.scala deleted file mode 100644 index 42327a4473..0000000000 --- a/tlair/src/test/scala/at/forsyte/apalache/tla/lir/transformations/standard/TestInlinerofUserOper.scala +++ /dev/null @@ -1,66 +0,0 @@ -package at.forsyte.apalache.tla.lir.transformations.standard - -import at.forsyte.apalache.tla.lir.{NameEx, SimpleFormalParam, TestingPredefs} -import at.forsyte.apalache.tla.lir.storage.BodyMapFactory -import at.forsyte.apalache.tla.lir.transformations.impl.IdleTracker -import at.forsyte.apalache.tla.lir.convenience.tla -import at.forsyte.apalache.tla.lir.UntypedPredefs._ -import org.junit.runner.RunWith -import org.scalatest.FunSuite -import org.scalatest.junit.JUnitRunner -import org.scalatest.prop.Checkers - -@RunWith(classOf[JUnitRunner]) -class TestInlinerofUserOper extends FunSuite with TestingPredefs with Checkers { - - import tla._ - - test("Test Inline") { - val cDecl = declOp("C", plus(n_x, int(1)), SimpleFormalParam("x")).untypedOperDecl() - val operDecls = Seq( - declOp("A", appOp(n_B)).untypedOperDecl(), - declOp("B", n_c).untypedOperDecl(), - cDecl - ) - - val bodies = BodyMapFactory.makeFromDecls(operDecls) - - val transformation = InlinerOfUserOper(bodies, new IdleTracker()) - - val ex1 = n_B - val ex2 = appOp(n_B).untyped() - val ex3 = n_A - val ex4 = appOp(n_A).untyped() - val ex5 = or(eql(int(1), int(0)), ge(appDecl(cDecl, appOp(n_A)), int(0))).untyped() - val ex6 = letIn( - appOp(NameEx("X")), - declOp("X", appOp(NameEx("C"), n_p)).untypedOperDecl() - ).untyped() - - val expected1 = n_B // Operators need to be called with apply - val expected2 = n_c - val expected3 = n_A // Operators need to be called with apply - val expected4 = n_c - val expected5 = or( - eql(int(1), int(0)), - ge(plus(n_c, int(1)), int(0)) - ).untyped() - val expected6 = letIn( - appOp(NameEx("X")), - declOp("X", plus(n_p, int(1))).untypedOperDecl() - ).untyped() - - val exs = Seq(ex1, ex2, ex3, ex4, ex5, ex6) - val expected = Seq(expected1, expected2, expected3, expected4, expected5, expected6) - val actual = exs map transformation - - assert(expected == actual) - - assertThrows[IllegalArgumentException] { - transformation(appOp(NameEx("C"))) - } - assertThrows[IllegalArgumentException] { - transformation(appOp(NameEx("C"), n_a, n_b)) - } - } -}