Skip to content

Commit

Permalink
Backport "Reimplement support for type aliases in SAM types" to 3.3.4 (
Browse files Browse the repository at this point in the history
…#21553)

Backports #18317 to 3.3.4-RC2 LTS. The PR fixes a regression introduced
in 3.3.2
  • Loading branch information
WojciechMazur authored Sep 6, 2024
2 parents 1a5bff6 + db126c1 commit 3d0f02c
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 45 deletions.
23 changes: 13 additions & 10 deletions compiler/src/dotty/tools/dotc/ast/tpd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -344,24 +344,27 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {

/** An anonymous class
*
* new parents { forwarders }
* new parents { termForwarders; typeAliases }
*
* where `forwarders` contains forwarders for all functions in `fns`.
* @param parents a non-empty list of class types
* @param fns a non-empty of functions for which forwarders should be defined in the class.
* The class has the same owner as the first function in `fns`.
* Its position is the union of all functions in `fns`.
* @param parents a non-empty list of class types
* @param termForwarders a non-empty list of forwarding definitions specified by their name and the definition they forward to.
* @param typeMembers a possibly-empty list of type members specified by their name and their right hand side.
*
* The class has the same owner as the first function in `termForwarders`.
* Its position is the union of all symbols in `termForwarders`.
*/
def AnonClass(parents: List[Type], fns: List[TermSymbol], methNames: List[TermName])(using Context): Block = {
AnonClass(fns.head.owner, parents, fns.map(_.span).reduceLeft(_ union _)) { cls =>
def forwarder(fn: TermSymbol, name: TermName) = {
def AnonClass(parents: List[Type], termForwarders: List[(TermName, TermSymbol)],
typeMembers: List[(TypeName, TypeBounds)] = Nil)(using Context): Block = {
AnonClass(termForwarders.head._2.owner, parents, termForwarders.map(_._2.span).reduceLeft(_ union _)) { cls =>
def forwarder(name: TermName, fn: TermSymbol) = {
val fwdMeth = fn.copy(cls, name, Synthetic | Method | Final).entered.asTerm
for overridden <- fwdMeth.allOverriddenSymbols do
if overridden.is(Extension) then fwdMeth.setFlag(Extension)
if !overridden.is(Deferred) then fwdMeth.setFlag(Override)
DefDef(fwdMeth, ref(fn).appliedToArgss(_))
}
fns.lazyZip(methNames).map(forwarder)
termForwarders.map((name, sym) => forwarder(name, sym)) ++
typeMembers.map((name, info) => TypeDef(newSymbol(cls, name, Synthetic, info).entered))
}
}

Expand Down
59 changes: 38 additions & 21 deletions compiler/src/dotty/tools/dotc/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5534,13 +5534,16 @@ object Types extends TypeUtils {
* and PolyType not allowed!) according to `possibleSamMethods`.
* - can be instantiated without arguments or with just () as argument.
*
* Additionally, a SAM type may contain type aliases refinements if they refine
* an existing type member.
*
* The pattern `SAMType(samMethod, samParent)` matches a SAM type, where `samMethod` is the
* type of the single abstract method and `samParent` is a subtype of the matched
* SAM type which has been stripped of wildcards to turn it into a valid parent
* type.
*/
object SAMType {
/** If possible, return a type which is both a subtype of `origTp` and a type
/** If possible, return a type which is both a subtype of `origTp` and a (possibly refined) type
* application of `samClass` where none of the type arguments are
* wildcards (thus making it a valid parent type), otherwise return
* NoType.
Expand Down Expand Up @@ -5570,27 +5573,41 @@ object Types extends TypeUtils {
* we arbitrarily pick the upper-bound.
*/
def samParent(origTp: Type, samClass: Symbol, samMeth: Symbol)(using Context): Type =
val tp = origTp.baseType(samClass)
val tp0 = origTp.baseType(samClass)

/** Copy type aliases refinements to `toTp` from `fromTp` */
def withRefinements(toType: Type, fromTp: Type): Type = fromTp.dealias match
case RefinedType(fromParent, name, info: TypeAlias) if tp0.member(name).exists =>
val parent1 = withRefinements(toType, fromParent)
RefinedType(toType, name, info)
case _ => toType
val tp = withRefinements(tp0, origTp)

if !(tp <:< origTp) then NoType
else tp match
case tp @ AppliedType(tycon, args) if tp.hasWildcardArg =>
val accu = new TypeAccumulator[VarianceMap[Symbol]]:
def apply(vmap: VarianceMap[Symbol], t: Type): VarianceMap[Symbol] = t match
case tp: TypeRef if tp.symbol.isAllOf(ClassTypeParam) =>
vmap.recordLocalVariance(tp.symbol, variance)
case _ =>
foldOver(vmap, t)
val vmap = accu(VarianceMap.empty, samMeth.info)
val tparams = tycon.typeParamSymbols
val args1 = args.zipWithConserve(tparams):
case (arg @ TypeBounds(lo, hi), tparam) =>
val v = vmap.computedVariance(tparam)
if v.uncheckedNN < 0 then lo
else hi
case (arg, _) => arg
tp.derivedAppliedType(tycon, args1)
case _ =>
tp
else
def approxWildcardArgs(tp: Type): Type = tp match
case tp @ AppliedType(tycon, args) if tp.hasWildcardArg =>
val accu = new TypeAccumulator[VarianceMap[Symbol]]:
def apply(vmap: VarianceMap[Symbol], t: Type): VarianceMap[Symbol] = t match
case tp: TypeRef if tp.symbol.isAllOf(ClassTypeParam) =>
vmap.recordLocalVariance(tp.symbol, variance)
case _ =>
foldOver(vmap, t)
val vmap = accu(VarianceMap.empty, samMeth.info)
val tparams = tycon.typeParamSymbols
val args1 = args.zipWithConserve(tparams):
case (arg @ TypeBounds(lo, hi), tparam) =>
val v = vmap.computedVariance(tparam)
if v.uncheckedNN < 0 then lo
else hi
case (arg, _) => arg
tp.derivedAppliedType(tycon, args1)
case tp @ RefinedType(parent, name, info) =>
tp.derivedRefinedType(approxWildcardArgs(parent), name, info)
case _ =>
tp
approxWildcardArgs(tp)
end samParent

def samClass(tp: Type)(using Context): Symbol = tp match
case tp: ClassInfo =>
Expand Down
33 changes: 19 additions & 14 deletions compiler/src/dotty/tools/dotc/transform/ExpandSAMs.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import core.*
import Scopes.newScope
import Contexts.*, Symbols.*, Types.*, Flags.*, Decorators.*, StdNames.*, Constants.*
import MegaPhase.*
import Names.TypeName
import Symbols.*
import NullOpsDecorator.*
import ast.untpd

Expand Down Expand Up @@ -50,16 +52,28 @@ class ExpandSAMs extends MiniPhase:
case tpe if defn.isContextFunctionType(tpe) =>
tree
case SAMType(_, tpe) if tpe.isRef(defn.PartialFunctionClass) =>
val tpe1 = checkRefinements(tpe, fn)
toPartialFunction(tree, tpe1)
toPartialFunction(tree, tpe)
case SAMType(_, tpe) if ExpandSAMs.isPlatformSam(tpe.classSymbol.asClass) =>
checkRefinements(tpe, fn)
tree
case tpe =>
val tpe1 = checkRefinements(tpe.stripNull, fn)
// A SAM type is allowed to have type aliases refinements (see
// SAMType#samParent) which must be converted into type members if
// the closure is desugared into a class.
val refinements = collection.mutable.ListBuffer[(TypeName, TypeAlias)]()
def collectAndStripRefinements(tp: Type): Type = tp match
case RefinedType(parent, name, info: TypeAlias) =>
val res = collectAndStripRefinements(parent)
refinements += ((name.asTypeName, info))
res
case _ => tp
val tpe1 = collectAndStripRefinements(tpe)
val Seq(samDenot) = tpe1.possibleSamMethods
cpy.Block(tree)(stats,
AnonClass(tpe1 :: Nil, fn.symbol.asTerm :: Nil, samDenot.symbol.asTerm.name :: Nil))
AnonClass(List(tpe1),
List(samDenot.symbol.asTerm.name -> fn.symbol.asTerm),
refinements.toList
)
)
}
case _ =>
tree
Expand Down Expand Up @@ -170,13 +184,4 @@ class ExpandSAMs extends MiniPhase:
List(isDefinedAtDef, applyOrElseDef)
}
}

private def checkRefinements(tpe: Type, tree: Tree)(using Context): Type = tpe.dealias match {
case RefinedType(parent, name, _) =>
if (name.isTermName && tpe.member(name).symbol.ownersIterator.isEmpty) // if member defined in the refinement
report.error(em"Lambda does not define $name", tree.srcPos)
checkRefinements(parent, tree)
case tpe =>
tpe
}
end ExpandSAMs
15 changes: 15 additions & 0 deletions tests/run/i18315.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
trait Sam1:
type T
def apply(x: T): T

trait Sam2:
var x: Int = 1 // To force anonymous class generation
type T
def apply(x: T): T

object Test:
def main(args: Array[String]): Unit =
val s1: Sam1 { type T = String } = x => x.trim
s1.apply("foo")
val s2: Sam2 { type T = Int } = x => x + 1
s2.apply(1)

0 comments on commit 3d0f02c

Please sign in to comment.