diff --git a/tla-pp/src/main/scala/at/forsyte/apalache/tla/pp/ExprOptimizer.scala b/tla-pp/src/main/scala/at/forsyte/apalache/tla/pp/ExprOptimizer.scala index b761370549..6bee553cd5 100644 --- a/tla-pp/src/main/scala/at/forsyte/apalache/tla/pp/ExprOptimizer.scala +++ b/tla-pp/src/main/scala/at/forsyte/apalache/tla/pp/ExprOptimizer.scala @@ -83,9 +83,12 @@ class ExprOptimizer(nameGen: UniqueNameGenerator, tracker: TransformationTracker private def transformCard: PartialFunction[TlaEx, TlaEx] = { case OperEx(TlaFiniteSetOper.cardinality, OperEx(TlaArithOper.dotdot, left, right)) => // A pattern that emerged in issue #748 - // Cardinality(a..b) is equivalent to (b - a) + 1. + // Cardinality(a..b) is equivalent to IF a =< b THEN (b - a) + 1 ELSE 0. + val condition = OperEx(TlaArithOper.le, left, right)(boolTag) val bMinusA = OperEx(TlaArithOper.minus, right, left)(intTag) - OperEx(TlaArithOper.plus, bMinusA, ValEx(TlaInt(1))(intTag))(intTag) + val bMinusAPlusOne = OperEx(TlaArithOper.plus, bMinusA, ValEx(TlaInt(1))(intTag))(intTag) + val zero = ValEx(TlaInt(0))(intTag) + OperEx(TlaControlOper.ifThenElse, condition, bMinusAPlusOne, zero)(intTag) case OperEx(TlaFiniteSetOper.cardinality, OperEx(TlaSetOper.powerset, set)) => // A pattern that emerged in issue #1360 diff --git a/tla-pp/src/test/scala/at/forsyte/apalache/tla/pp/TestExprOptimizer.scala b/tla-pp/src/test/scala/at/forsyte/apalache/tla/pp/TestExprOptimizer.scala index dffea9912c..73e093b4ca 100644 --- a/tla-pp/src/test/scala/at/forsyte/apalache/tla/pp/TestExprOptimizer.scala +++ b/tla-pp/src/test/scala/at/forsyte/apalache/tla/pp/TestExprOptimizer.scala @@ -156,11 +156,15 @@ class TestExprOptimizer extends AnyFunSuite with BeforeAndAfterEach { optimizer.apply(input) } - test("""Cardinality(a..b) becomes (b - a) + 1""") { + test("""Cardinality(a..b) becomes IF a =< b THEN (b - a) + 1 ELSE 0""") { val input = card(dotdot(name("a").as(intT), name("b").as(intT)).as(intSetT)).as(intT) val output = optimizer.apply(input) val expected = - plus(minus(name("b").as(intT), name("a").as(intT)).as(intT), int(1).as(intT)).as(intT) + ite( + le(name("a").as(intT), name("b").as(intT)).as(boolT), + plus(minus(name("b").as(intT), name("a").as(intT)).as(intT), int(1)).as(intT), + int(0), + ).as(intT) assert(expected == output) }