diff --git a/tla-pp/src/main/scala/at/forsyte/apalache/tla/pp/SetMembershipSimplifier.scala b/tla-pp/src/main/scala/at/forsyte/apalache/tla/pp/SetMembershipSimplifier.scala index e0faf5917c..771b6a0570 100644 --- a/tla-pp/src/main/scala/at/forsyte/apalache/tla/pp/SetMembershipSimplifier.scala +++ b/tla-pp/src/main/scala/at/forsyte/apalache/tla/pp/SetMembershipSimplifier.scala @@ -7,9 +7,18 @@ import at.forsyte.apalache.tla.lir.transformations.{LanguageWatchdog, Transforma import at.forsyte.apalache.tla.lir.values._ /** - * A simplifier that rewrites vacuously true membership tests based on type information. + * A simplifier that rewrites expressions commonly found in `TypeOK`. Assumes expressions to be well-typed. * - * For example, `x \in BOOLEAN` is rewritten to `TRUE` if x is typed BoolT1. + * After Apalache's type-checking, we can rewrite some expressions to simpler forms. For example, the (after + * type-checking) vacuously true `x \in BOOLEAN` is rewritten to `TRUE` (as `x` must be a `BoolT1`). + * + * We currently perform the following simplifications (for type-defining sets TDS, see [[isTypeDefining]]): + * - `n \in Nat` ~> `x >= 0` + * - `b \in BOOLEAN`, `i \in Int`, `r \in Real` ~> `TRUE` + * - `seq \in Seq(TDS)` ~> `TRUE` + * - `set \in SUBSET TDS` ~> `TRUE` + * - `fun \in [TDS1 -> TDS2]` ~> `TRUE` + * - `fun \in [Dom -> TDS]` ~> `DOMAIN fun = Dom` * * @author * Thomas Pani @@ -27,46 +36,57 @@ class SetMembershipSimplifier(tracker: TransformationTracker) extends AbstractTr } /** - * Returns the type of a TLA+ predefined set, if rewriting set membership to TRUE is applicable. In particular, it is - * *not* applicable to `Nat`, since `i \in Nat` does not hold for all `IntT1`-typed `i`. + * Returns true iff the passed TLA+ expression is a type-defining set. Type-defining sets contain all of the values of + * their respective element type. + * + * The type-defining sets are inductively defined as + * - the predefined sets BOOLEAN, Int, Real, STRING, + * - sets of sequences over type-defining sets, e.g., Seq(BOOLEAN), Seq(Int), Seq(Seq(Int)), Seq(SUBSET Int), ... + * - power sets of type-defining sets, e.g., SUBSET BOOLEAN, SUBSET Int, SUBSET Seq(Int), ... + * - sets of functions over type-defining sets, e.g., [Int -> BOOLEAN], ... + * + * In particular, `Nat` is not type-defining, nor are sequence sets / power sets thereof, since `i \in Nat` does not + * hold for all `IntT1`-typed `i`. */ - private def typeOfSupportedPredefSet: PartialFunction[TlaPredefSet, TlaType1] = { - case TlaBoolSet => BoolT1() - case TlaIntSet => IntT1() - case TlaRealSet => RealT1() - case TlaStrSet => StrT1() - // intentionally omits TlaNatSet, see above. - } + private def isTypeDefining: Function[TlaEx, Boolean] = { + // base case: BOOLEAN, Int, Real, STRING + case ValEx(TlaBoolSet) | ValEx(TlaIntSet) | ValEx(TlaRealSet) | ValEx(TlaStrSet) => true - /** - * Checks if this transformation is applicable (see [[typeOfSupportedPredefSet]]) to a TLA+ predefined set `ps` of - * primitives, and if the types of `name` and `ps` match. - */ - private def isApplicable(name: TlaEx, ps: TlaPredefSet): Boolean = - typeOfSupportedPredefSet.isDefinedAt(ps) && name.typeTag == Typed(typeOfSupportedPredefSet(ps)) + // inductive cases: + // 1. Seq(s) for a type-defining set `s` + case OperEx(TlaSetOper.seqSet, set) => isTypeDefining(set) + // 2. SUBSET s for a type-defining set `s` + case OperEx(TlaSetOper.powerset, set) => isTypeDefining(set) + // 3. [s1 -> s2] for type-defining sets `s1` and `s2 + case OperEx(TlaSetOper.funSet, set1, set2) => isTypeDefining(set1) && isTypeDefining(set2) - /** - * Checks if this transformation is applicable (see [[typeOfSupportedPredefSet]]) to a TLA+ predefined set of - * sequences (`Seq(_)`) `ps`, and if the types of `name` and `ps` match. - */ - private def isApplicableSeq(name: TlaEx, ps: TlaPredefSet): Boolean = - typeOfSupportedPredefSet.isDefinedAt(ps) && name.typeTag == Typed(SeqT1(typeOfSupportedPredefSet(ps))) + // otherwise + case _ => false + } /** - * Rewrites vacuously true membership tests based on type information, and rewrites `i \in Nat` to `i \ge 0`. + * Simplifies expressions commonly found in `TypeOK`, assuming they are well-typed. * - * For example, `x \in BOOLEAN` is rewritten to `TRUE` if `x` is typed `BoolT1`. + * @see + * [[SetMembershipSimplifier]] for a full list of supported rewritings. */ private def transformMembership: PartialFunction[TlaEx, TlaEx] = { - // n \in Nat -> x >= 0 - case OperEx(TlaSetOper.in, name, ValEx(TlaNatSet)) if name.typeTag == Typed(IntT1()) => - OperEx(TlaArithOper.ge, name, ValEx(TlaInt(0))(intTag))(boolTag) - // b \in BOOLEAN, i \in Int, r \in Real -> TRUE - case OperEx(TlaSetOper.in, name, ValEx(ps: TlaPredefSet)) if isApplicable(name, ps) => trueVal - // seq \in Seq(_) -> TRUE - case OperEx(TlaSetOper.in, name, OperEx(TlaSetOper.seqSet, ValEx(ps: TlaPredefSet))) if isApplicableSeq(name, ps) => - trueVal - // return `ex` unchanged + case ex @ OperEx(TlaSetOper.in, name, set) => + set match { + // x \in TDS ~> TRUE + case set if isTypeDefining(set) => trueVal + + // n \in Nat ~> x >= 0 + case ValEx(TlaNatSet) => OperEx(TlaArithOper.ge, name, ValEx(TlaInt(0))(intTag))(boolTag) + + // fun \in [Dom -> TDS] ~> DOMAIN fun = Dom (fun \in [TDS1 -> TDS2] is handled above) + case OperEx(TlaSetOper.funSet, domain, set2) if isTypeDefining(set2) => + OperEx(TlaOper.eq, OperEx(TlaFunOper.domain, name)(domain.typeTag), domain)(boolTag) + + // otherwise, return `ex` unchanged + case _ => ex + } + // return non-set membership expressions unchanged case ex => ex } } diff --git a/tla-pp/src/test/scala/at/forsyte/apalache/tla/pp/TestSetMembershipSimplifier.scala b/tla-pp/src/test/scala/at/forsyte/apalache/tla/pp/TestSetMembershipSimplifier.scala index 2e1e3e730e..299836f6e3 100644 --- a/tla-pp/src/test/scala/at/forsyte/apalache/tla/pp/TestSetMembershipSimplifier.scala +++ b/tla-pp/src/test/scala/at/forsyte/apalache/tla/pp/TestSetMembershipSimplifier.scala @@ -28,6 +28,7 @@ class TestSetMembershipSimplifier private val boolName = tla.name("b").as(BoolT1()) private val strName = tla.name("s").as(StrT1()) private val intName = tla.name("i").as(IntT1()) + private val funName = tla.name("fun").as(FunT1(IntT1(), BoolT1())) private val boolSet = tla.booleanSet().as(SetT1(BoolT1())) private val strSet = tla.stringSet().as(SetT1(StrT1())) @@ -46,6 +47,19 @@ class TestSetMembershipSimplifier private val intSeqSet = tla.seqSet(tla.intSet()).as(SetT1(SeqT1(IntT1()))) private val natSeqSet = tla.seqSet(tla.natSet()).as(SetT1(SeqT1(IntT1()))) + private val boolSetVal = tla.enumSet(boolVal, boolName).as(SetT1(BoolT1())) + private val strSetVal = tla.enumSet(strVal, strName).as(SetT1(StrT1())) + private val intSetVal = tla.enumSet(intVal, intName).as(SetT1(IntT1())) + + private val boolSetName = tla.name("boolSet").as(SetT1(BoolT1())) + private val strSetName = tla.name("strSet").as(SetT1(StrT1())) + private val intSetName = tla.name("intSet").as(SetT1(IntT1())) + + private val boolPowerset = tla.seqSet(tla.booleanSet()).as(SetT1(SeqT1(BoolT1()))) + private val strPowerset = tla.seqSet(tla.stringSet()).as(SetT1(SeqT1(StrT1()))) + private val intPowerset = tla.seqSet(tla.intSet()).as(SetT1(SeqT1(IntT1()))) + private val natPowerset = tla.seqSet(tla.natSet()).as(SetT1(SeqT1(IntT1()))) + val expressions = List( (boolName, boolVal, boolSet), (strName, strVal, strSet), @@ -53,6 +67,9 @@ class TestSetMembershipSimplifier (boolSeqName, boolSeqVal, boolSeqSet), (strSeqName, strSeqVal, strSeqSet), (intSeqName, intSeqVal, intSeqSet), + (boolSetName, boolSetVal, boolPowerset), + (strSetName, strSetVal, strPowerset), + (intSetName, intSetVal, intPowerset), ) override def beforeEach(): Unit = { @@ -60,40 +77,90 @@ class TestSetMembershipSimplifier } test("simplifies appropriately-typed set membership") { + // i \in Nat ~> i >= 0 + val intNameInNat = tla.in(intName, tla.natSet()).as(BoolT1()) + val intValInNat = tla.in(intVal, tla.natSet()).as(BoolT1()) + simplifier(intNameInNat) shouldBe tla.ge(intName, tla.int(0)).as(BoolT1()) + simplifier(intValInNat) shouldBe tla.ge(intVal, tla.int(0)).as(BoolT1()) + + /* *** tests for all supported types of applicable sets *** */ + expressions.foreach { case (name, value, set) => + // name \in ApplicableSet ~> TRUE + // e.g., b \in BOOLEAN, i \in Int, ... ~> TRUE val inputName = tla.in(name, set).as(BoolT1()) simplifier(inputName) shouldBe tlaTrue + // literal \in ApplicableSet ~> TRUE + // e.g., TRUE \in BOOLEAN, 42 \in Int, ... ~> TRUE val inputValue = tla.in(value, set).as(BoolT1()) simplifier(inputValue) shouldBe tlaTrue + /* *** nested cases *** */ + + // name \in ApplicableSet /\ _ ~> TRUE + // e.g., i \in Int /\ _, ... ~> TRUE val nestedInputName = tla.and(tla.in(name, set).as(BoolT1()), tlaTrue).as(BoolT1()) simplifier(nestedInputName) shouldBe tla.and(tlaTrue, tlaTrue).as(BoolT1()) + // literal \in ApplicableSet /\ _ ~> TRUE + // e.g., 42 \in Int /\ _, ... ~> TRUE val nestedInputValue = tla.and(tla.in(name, set).as(BoolT1()), tlaTrue).as(BoolT1()) simplifier(nestedInputValue) shouldBe tla.and(tlaTrue, tlaTrue).as(BoolT1()) - val intNameInNat = tla.in(intName, tla.natSet()).as(BoolT1()) - val intValInNat = tla.in(intVal, tla.natSet()).as(BoolT1()) - simplifier(intNameInNat) shouldBe tla.ge(intName, tla.int(0)).as(BoolT1()) - simplifier(intValInNat) shouldBe tla.ge(intVal, tla.int(0)).as(BoolT1()) - } - } - - test("returns inappropriately-typed set membership unchanged") { - expressions.foreach { case (name, value, _) => - expressions.filter { case (name2, _, _) => name != name2 }.foreach { case (_, _, otherSet) => - val inputName = tla.in(name, otherSet).as(BoolT1()) - simplifier(inputName) shouldBe inputName - - val inputValue = tla.in(value, otherSet).as(BoolT1()) - simplifier(inputValue) shouldBe inputValue + // fun \in [ApplicableSet1 -> ApplicableSet2], ... ~> TRUE + expressions.foreach { case (name2, _, set2) => + val funSetType = SetT1(FunT1(name.typeTag.asTlaType1(), name2.typeTag.asTlaType1())) + val funInFunSet = tla.in(funName, tla.funSet(set, set2).as(funSetType)).as(BoolT1()) + simplifier(funInFunSet) shouldBe tlaTrue } } + /* *** tests of particular expressions *** */ + + // <<{{1}}>> \in Seq(SUBSET Int) ~> TRUE + val setOfSetOfInt = SetT1(SetT1(IntT1())) + val seqOfSetOfSetOfInt = SeqT1(setOfSetOfInt) + val nestedSeqSubsetVal = + tla.tuple(tla.enumSet(intSetVal).as(setOfSetOfInt)).as(SeqT1(setOfSetOfInt)).as(seqOfSetOfSetOfInt) + val nestedSeqSubsetTest = + tla.in(nestedSeqSubsetVal, tla.seqSet(tla.powSet(intSet).as(setOfSetOfInt)).as(seqOfSetOfSetOfInt)).as(BoolT1()) + simplifier(nestedSeqSubsetTest) shouldBe tlaTrue + + // {<<1>>} \in SUBSET (Seq(Int)) ~> TRUE + val setOfSeqOfInt = SetT1(SeqT1(IntT1())) + val nestedSubsetSeqVal = tla.enumSet(tla.tuple(intVal).as(SeqT1(IntT1()))).as(setOfSeqOfInt) + val nestedSubsetSeqTest = tla.in(nestedSubsetSeqVal, tla.powSet(intSeqSet).as(setOfSeqOfInt)).as(BoolT1()) + simplifier(nestedSubsetSeqTest) shouldBe tlaTrue + + // fun \in [Seq(SUBSET Int) -> SUBSET Seq(BOOLEAN)], ... ~> TRUE + val intPowersetSeqType = SeqT1(SetT1(IntT1())) + val boolSeqPowersetType = SetT1(SeqT1(BoolT1())) + val nestedFunSetType = SetT1(FunT1(intPowersetSeqType, boolSeqPowersetType)) + val nestedInput = tla + .in(funName, + tla + .funSet(tla.seqSet(intSeqSet).as(intPowersetSeqType), tla.powSet(boolSeqSet).as(boolSeqPowersetType)) + .as(nestedFunSetType)) + .as(BoolT1()) + simplifier(nestedInput) shouldBe tlaTrue + + // fun \in [RM -> PredefSet], ... ~> DOMAIN fun = RM + val domain = tla.name("RM").as(SetT1(IntT1())) + val funSetType = SetT1(FunT1(BoolT1(), IntT1())) + val funConstToBoolean = tla.in(funName, tla.funSet(domain, boolSet).as(funSetType)).as(BoolT1()) + simplifier(funConstToBoolean) shouldBe tla.eql(tla.dom(funName).as(SetT1(IntT1())), domain).as(BoolT1()) + } + + test("returns myInt \\in Nat unchanged") { val intSeqNameInSeqNat = tla.in(intSeqName, natSeqSet).as(BoolT1()) val intSeqValInSeqNat = tla.in(intSeqVal, natSeqSet).as(BoolT1()) simplifier(intSeqNameInSeqNat) shouldBe intSeqNameInSeqNat simplifier(intSeqValInSeqNat) shouldBe intSeqValInSeqNat + + val intSetNameInNatPowerset = tla.in(intSetName, natPowerset).as(BoolT1()) + val intSetValInNatPowerset = tla.in(intSetVal, natPowerset).as(BoolT1()) + simplifier(intSetNameInNatPowerset) shouldBe intSetNameInNatPowerset + simplifier(intSetValInNatPowerset) shouldBe intSetValInNatPowerset } }