Skip to content

Commit

Permalink
Merge pull request #1613 from informalsystems/1606-unsound-optimizati…
Browse files Browse the repository at this point in the history
…on-of-cardinalityab-when-a->-b

Fix bug in cardinality optimization
  • Loading branch information
p-offtermatt committed Apr 8, 2022
2 parents 4f14fb8 + d7f3a29 commit 0072050
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,12 @@ class ExprOptimizer(nameGen: UniqueNameGenerator, tracker: TransformationTracker
private def transformCard: PartialFunction[TlaEx, TlaEx] = {
case OperEx(TlaFiniteSetOper.cardinality, OperEx(TlaArithOper.dotdot, left, right)) =>
// A pattern that emerged in issue #748
// Cardinality(a..b) is equivalent to (b - a) + 1.
// Cardinality(a..b) is equivalent to IF a =< b THEN (b - a) + 1 ELSE 0.
val condition = OperEx(TlaArithOper.le, left, right)(boolTag)
val bMinusA = OperEx(TlaArithOper.minus, right, left)(intTag)
OperEx(TlaArithOper.plus, bMinusA, ValEx(TlaInt(1))(intTag))(intTag)
val bMinusAPlusOne = OperEx(TlaArithOper.plus, bMinusA, ValEx(TlaInt(1))(intTag))(intTag)
val zero = ValEx(TlaInt(0))(intTag)
OperEx(TlaControlOper.ifThenElse, condition, bMinusAPlusOne, zero)(intTag)

case OperEx(TlaFiniteSetOper.cardinality, OperEx(TlaSetOper.powerset, set)) =>
// A pattern that emerged in issue #1360
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,11 +156,15 @@ class TestExprOptimizer extends AnyFunSuite with BeforeAndAfterEach {
optimizer.apply(input)
}

test("""Cardinality(a..b) becomes (b - a) + 1""") {
test("""Cardinality(a..b) becomes IF a =< b THEN (b - a) + 1 ELSE 0""") {
val input = card(dotdot(name("a").as(intT), name("b").as(intT)).as(intSetT)).as(intT)
val output = optimizer.apply(input)
val expected =
plus(minus(name("b").as(intT), name("a").as(intT)).as(intT), int(1).as(intT)).as(intT)
ite(
le(name("a").as(intT), name("b").as(intT)).as(boolT),
plus(minus(name("b").as(intT), name("a").as(intT)).as(intT), int(1)).as(intT),
int(0),
).as(intT)
assert(expected == output)
}

Expand Down

0 comments on commit 0072050

Please sign in to comment.