From ab2bb0759c5f851a81504517e7908bfbdcabd439 Mon Sep 17 00:00:00 2001 From: mario-bucev Date: Sun, 19 May 2024 16:08:07 +0200 Subject: [PATCH] Fix #1530 (#1532) allow using `old` in ensuring of inner functions even if variable is unchanged MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Viktor KunĨak --- .../ImperativeCodeElimination.scala | 16 +++++++----- .../benchmarks/imperative/valid/i1530.scala | 26 +++++++++++++++++++ 2 files changed, 35 insertions(+), 7 deletions(-) create mode 100644 frontends/benchmarks/imperative/valid/i1530.scala diff --git a/core/src/main/scala/stainless/extraction/imperative/ImperativeCodeElimination.scala b/core/src/main/scala/stainless/extraction/imperative/ImperativeCodeElimination.scala index b44445562..644bb9f58 100644 --- a/core/src/main/scala/stainless/extraction/imperative/ImperativeCodeElimination.scala +++ b/core/src/main/scala/stainless/extraction/imperative/ImperativeCodeElimination.scala @@ -30,7 +30,7 @@ class ImperativeCodeElimination(override val s: Trees)(override val t: s.type) import symbols.{given, _} import exprOps._ import exprOps.{ replaceKeepPositions => replace } - + val outerFd = fd /* varsInScope refers to variable declared in the same level scope. * Typically, when entering a nested function body, the scope should be * reset to empty */ @@ -206,7 +206,8 @@ class ImperativeCodeElimination(override val s: Trees)(override val t: s.type) val (fdRes, fdScope, _) = toFunction(bd) fdScope(fdRes) } - val newSpecs = specs.map(rewriteSpecs) + val allParams = outerFd.params ++ state.localsMapping.values.flatMap(_._1.params) ++ fd.params + val newSpecs = specs.map(rewriteSpecs(allParams, _)) val newFd = inner.copy(fullBody = reconstructSpecs(newSpecs, newBody, inner.returnType)) val (bodyRes, bodyScope, bodyFun) = toFunction(b) (bodyRes, (b2: Expr) => LetRec(Seq(newFd.toLocal), bodyScope(b2)).setPos(fd).copiedFrom(expr), bodyFun) @@ -257,8 +258,9 @@ class ImperativeCodeElimination(override val s: Trees)(override val t: s.type) val freshBody = postMap { case Assignment(v, e) => rewritingMap.get(v).map(nv => Assignment(nv, e)) case v: Variable => rewritingMap.get(v).orElse(extraVarReplace.get(v)) - case Old(v: Variable) if rewriteOldExpr && freshVarDecls.contains(v) => - Some(freshVars(freshVarDecls.indexOf(v))) + case Old(v: Variable) if rewriteOldExpr => + if (freshVarDecls.contains(v)) Some(freshVars(freshVarDecls.indexOf(v))) + else Some(v) // This case occurs if `v` is not mutated by the function, in which case old(v) = v case _ => None } (body) val wrappedBody = bodyWrapper(freshBody, freshVarDecls) @@ -425,7 +427,7 @@ class ImperativeCodeElimination(override val s: Trees)(override val t: s.type) } // NOTE: We assume that lets wrapping specs require no rewriting - def rewriteSpecs(spec: Specification)(using State): Specification = { + def rewriteSpecs(fdparams: Seq[ValDef], spec: Specification)(using State): Specification = { def toFn(expr: Expr): Expr = { val (res, scope, _) = toFunction(expr) scope(res) @@ -434,7 +436,7 @@ class ImperativeCodeElimination(override val s: Trees)(override val t: s.type) case Postcondition(ld @ Lambda(params, body)) => // Remove `Old` trees for function parameters on which no effect occurred val newBody = replaceSingle( - fd.params.map(vd => Old(vd.toVariable) -> vd.toVariable).toMap, + fdparams.map(vd => Old(vd.toVariable) -> vd.toVariable).toMap, body ) Postcondition(Lambda(params, toFn(newBody)).copiedFrom(ld)).setPos(spec) @@ -453,7 +455,7 @@ class ImperativeCodeElimination(override val s: Trees)(override val t: s.type) val specced = BodyWithSpecs(fd.fullBody) val newSpecced = specced.copy( - specs = specced.specs.map(rewriteSpecs), + specs = specced.specs.map(rewriteSpecs(fd.params, _)), body = topLevelRewrite(specced.body) ) diff --git a/frontends/benchmarks/imperative/valid/i1530.scala b/frontends/benchmarks/imperative/valid/i1530.scala new file mode 100644 index 000000000..7494959c9 --- /dev/null +++ b/frontends/benchmarks/imperative/valid/i1530.scala @@ -0,0 +1,26 @@ +import stainless.lang.* + +object i1530 { + case class A(var i: BigInt) + + def outer1(a: A): Unit = { + def inner11(a1: A): Unit = { + () + }.ensuring(_ => a1.i == old(a1).i) + + def inner12(a1: A, a2: A): Unit = { + a2.i += 1 + }.ensuring(_ => a1.i == old(a1).i) + } + + def outer2(a: A): Unit = { + def inner21(): Unit = { + }.ensuring(_ => a.i == old(a).i) + } + + def outer3(a: A): Unit = { + def inner3(a2: A): Unit = { + a.i += 1 + }.ensuring(_ => a.i == old(a).i + 1 && a2.i == old(a2).i) + } +} \ No newline at end of file