Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle more corner cases in etaReduce #14628

Merged
merged 1 commit into from
Mar 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1594,6 +1594,9 @@ class Definitions {
yield
nme.apply.specializedFunction(r, List(t1, t2)).asTermName

@tu lazy val FunctionSpecializedApplyNames: collection.Set[Name] =
Function0SpecializedApplyNames ++ Function1SpecializedApplyNames ++ Function2SpecializedApplyNames

def functionArity(tp: Type)(using Context): Int = tp.dropDependentRefinement.dealias.argInfos.length - 1

/** Return underlying context function type (i.e. instance of an ContextFunctionN class)
Expand Down
49 changes: 36 additions & 13 deletions compiler/src/dotty/tools/dotc/transform/EtaReduce.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import MegaPhase.MiniPhase
import core.*
import Symbols.*, Contexts.*, Types.*, Decorators.*
import StdNames.nme
import SymUtils.*
import NameKinds.AdaptedClosureName

/** Rewrite `(x1, ... xN) => f(x1, ... xN)` for N >= 0 to `f`,
* provided `f` is a pure path of function type.
Expand All @@ -15,6 +17,11 @@ import StdNames.nme
* where a context function is expected, unless that value has the
* syntactic form of a context function literal.
*
* Also handle variants of eta-expansions where
* - result f.apply(X_1,...,X_n) is subject to a synthetic cast, or
* - the application uses a specialized apply method, or
* - the closure is adapted (see Erasure#adaptClosure)
*
* Without this phase, when a contextual function is passed as an argument to a
* recursive function, that would have the unfortunate effect of a linear growth
* in transient thunks of identical type wrapped around each other, leading
Expand All @@ -27,20 +34,36 @@ class EtaReduce extends MiniPhase:

override def description: String = EtaReduce.description

override def transformBlock(tree: Block)(using Context): Tree = tree match
case Block((meth : DefDef) :: Nil, closure: Closure)
if meth.symbol == closure.meth.symbol =>
meth.rhs match
case Apply(Select(fn, nme.apply), args)
if meth.paramss.head.corresponds(args)((param, arg) =>
override def transformBlock(tree: Block)(using Context): Tree =

def tryReduce(mdef: DefDef, rhs: Tree): Tree = rhs match
case Apply(Select(fn, name), args)
if (name == nme.apply || defn.FunctionSpecializedApplyNames.contains(name))
&& mdef.paramss.head.corresponds(args)((param, arg) =>
arg.isInstanceOf[Ident] && arg.symbol == param.symbol)
&& isPurePath(fn)
&& fn.tpe <:< tree.tpe
&& defn.isFunctionClass(fn.tpe.widen.typeSymbol) =>
report.log(i"eta reducing $tree --> $fn")
fn
case _ => tree
case _ => tree
&& isPurePath(fn)
&& fn.tpe <:< tree.tpe
&& defn.isFunctionClass(fn.tpe.widen.typeSymbol) =>
report.log(i"eta reducing $tree --> $fn")
fn
case TypeApply(Select(qual, _), _) if rhs.symbol.isTypeCast && rhs.span.isSynthetic =>
tryReduce(mdef, qual)
case _ =>
tree

tree match
case Block((meth: DefDef) :: Nil, expr) if meth.symbol.isAnonymousFunction =>
expr match
case closure: Closure if meth.symbol == closure.meth.symbol =>
tryReduce(meth, meth.rhs)
case Block((adapted: DefDef) :: Nil, closure: Closure)
if adapted.name.is(AdaptedClosureName) && adapted.symbol == closure.meth.symbol =>
tryReduce(meth, meth.rhs)
case _ =>
tree
case _ =>
tree
end transformBlock

end EtaReduce

Expand Down
15 changes: 15 additions & 0 deletions tests/run/i14623.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
object Thunk {
private[this] val impl =
((x: Any) => x).asInstanceOf[(=> Any) => Function0[Any]]

def asFunction0[A](thunk: => A): Function0[A] = impl(thunk).asInstanceOf[Function0[A]]
}

@main def Test =
var i = 0
val f1 = { () => i += 1; "" }
assert(Thunk.asFunction0(f1()) eq f1)
val f2 = { () => i += 1; i }
assert(Thunk.asFunction0(f2()) eq f2)
val f3 = { () => i += 1 }
assert(Thunk.asFunction0(f3()) eq f3)