Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix #1530 #1532

Merged
merged 2 commits into from
May 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class ImperativeCodeElimination(override val s: Trees)(override val t: s.type)
import symbols.{given, _}
import exprOps._
import exprOps.{ replaceKeepPositions => replace }

val outerFd = fd
/* varsInScope refers to variable declared in the same level scope.
* Typically, when entering a nested function body, the scope should be
* reset to empty */
Expand Down Expand Up @@ -206,7 +206,8 @@ class ImperativeCodeElimination(override val s: Trees)(override val t: s.type)
val (fdRes, fdScope, _) = toFunction(bd)
fdScope(fdRes)
}
val newSpecs = specs.map(rewriteSpecs)
val allParams = outerFd.params ++ state.localsMapping.values.flatMap(_._1.params) ++ fd.params
val newSpecs = specs.map(rewriteSpecs(allParams, _))
val newFd = inner.copy(fullBody = reconstructSpecs(newSpecs, newBody, inner.returnType))
val (bodyRes, bodyScope, bodyFun) = toFunction(b)
(bodyRes, (b2: Expr) => LetRec(Seq(newFd.toLocal), bodyScope(b2)).setPos(fd).copiedFrom(expr), bodyFun)
Expand Down Expand Up @@ -257,8 +258,9 @@ class ImperativeCodeElimination(override val s: Trees)(override val t: s.type)
val freshBody = postMap {
case Assignment(v, e) => rewritingMap.get(v).map(nv => Assignment(nv, e))
case v: Variable => rewritingMap.get(v).orElse(extraVarReplace.get(v))
case Old(v: Variable) if rewriteOldExpr && freshVarDecls.contains(v) =>
Some(freshVars(freshVarDecls.indexOf(v)))
case Old(v: Variable) if rewriteOldExpr =>
if (freshVarDecls.contains(v)) Some(freshVars(freshVarDecls.indexOf(v)))
else Some(v) // This case occurs if `v` is not mutated by the function, in which case old(v) = v
case _ => None
} (body)
val wrappedBody = bodyWrapper(freshBody, freshVarDecls)
Expand Down Expand Up @@ -425,7 +427,7 @@ class ImperativeCodeElimination(override val s: Trees)(override val t: s.type)
}

// NOTE: We assume that lets wrapping specs require no rewriting
def rewriteSpecs(spec: Specification)(using State): Specification = {
def rewriteSpecs(fdparams: Seq[ValDef], spec: Specification)(using State): Specification = {
def toFn(expr: Expr): Expr = {
val (res, scope, _) = toFunction(expr)
scope(res)
Expand All @@ -434,7 +436,7 @@ class ImperativeCodeElimination(override val s: Trees)(override val t: s.type)
case Postcondition(ld @ Lambda(params, body)) =>
// Remove `Old` trees for function parameters on which no effect occurred
val newBody = replaceSingle(
fd.params.map(vd => Old(vd.toVariable) -> vd.toVariable).toMap,
fdparams.map(vd => Old(vd.toVariable) -> vd.toVariable).toMap,
body
)
Postcondition(Lambda(params, toFn(newBody)).copiedFrom(ld)).setPos(spec)
Expand All @@ -453,7 +455,7 @@ class ImperativeCodeElimination(override val s: Trees)(override val t: s.type)
val specced = BodyWithSpecs(fd.fullBody)

val newSpecced = specced.copy(
specs = specced.specs.map(rewriteSpecs),
specs = specced.specs.map(rewriteSpecs(fd.params, _)),
body = topLevelRewrite(specced.body)
)

Expand Down
26 changes: 26 additions & 0 deletions frontends/benchmarks/imperative/valid/i1530.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import stainless.lang.*

object i1530 {
case class A(var i: BigInt)

def outer1(a: A): Unit = {
def inner11(a1: A): Unit = {
()
}.ensuring(_ => a1.i == old(a1).i)

def inner12(a1: A, a2: A): Unit = {
a2.i += 1
}.ensuring(_ => a1.i == old(a1).i)
}

def outer2(a: A): Unit = {
def inner21(): Unit = {
}.ensuring(_ => a.i == old(a).i)
}

def outer3(a: A): Unit = {
def inner3(a2: A): Unit = {
a.i += 1
}.ensuring(_ => a.i == old(a).i + 1 && a2.i == old(a2).i)
}
}