Skip to content

Commit

Permalink
Fix #9028: Introduce super traits
Browse files Browse the repository at this point in the history
and eliminate super traits in widenInferred
  • Loading branch information
odersky committed Jun 9, 2020
1 parent 86bd4e9 commit ff9798d
Show file tree
Hide file tree
Showing 18 changed files with 124 additions and 44 deletions.
2 changes: 2 additions & 0 deletions compiler/src/dotty/tools/dotc/ast/untpd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,8 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
case class Inline()(implicit @constructorOnly src: SourceFile) extends Mod(Flags.Inline)

case class Transparent()(implicit @constructorOnly src: SourceFile) extends Mod(Flags.EmptyFlags)

case class Super()(implicit @constructorOnly src: SourceFile) extends Mod(Flags.SuperTrait)
}

/** Modifiers and annotations for definitions
Expand Down
49 changes: 31 additions & 18 deletions compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala
Original file line number Diff line number Diff line change
Expand Up @@ -300,8 +300,9 @@ trait ConstraintHandling[AbstractContext] {
* (i.e. `inst.widenSingletons <:< bound` succeeds with satisfiable constraint)
* 2. If `inst` is a union type, approximate the union type from above by an intersection
* of all common base types, provided the result is a subtype of `bound`.
* 3. (currently not enabled, see #9028) If `inst` is an intersection with some restricted base types, drop
* the restricted base types from the intersection, provided the result is a subtype of `bound`.
* 3. If `inst` a super trait instance or an intersection with some super trait
* parents, replace all super trait instances with AnyRef (or Any, if the trait
* is a universal trait) as long as the result is a subtype of `bound`.
*
* Don't do these widenings if `bound` is a subtype of `scala.Singleton`.
* Also, if the result of these widenings is a TypeRef to a module class,
Expand All @@ -313,21 +314,36 @@ trait ConstraintHandling[AbstractContext] {
*/
def widenInferred(inst: Type, bound: Type)(implicit actx: AbstractContext): Type =

def isRestricted(tp: Type) = tp.typeSymbol == defn.EnumValueClass // for now, to be generalized later
def dropSuperTraits(tp: Type): Type =
var keep: Set[Type] = Set() // types to keep since otherwise bound would not fit
var lastDropped: Type = NoType // the last type dropped in dropOneSuperTrait

def dropOneSuperTrait(tp: Type): Type =
val tpd = tp.dealias
if tpd.typeSymbol.isSuperTrait && !tpd.isLambdaSub && !keep.contains(tpd) then
lastDropped = tpd
if tpd.derivesFrom(defn.ObjectClass) then defn.ObjectType else defn.AnyType
else tpd match
case AndType(tp1, tp2) =>
val tp1w = dropOneSuperTrait(tp1)
if tp1w ne tp1 then tp1w & tp2
else
val tp2w = dropOneSuperTrait(tp2)
if tp2w ne tp2 then tp1 & tp2w
else tpd
case _ =>
tp

def dropRestricted(tp: Type): Type = tp.dealias match
case tpd @ AndType(tp1, tp2) =>
if isRestricted(tp1) then tp2
else if isRestricted(tp2) then tp1
def recur(tp: Type): Type =
val tpw = dropOneSuperTrait(tp)
if tpw eq tp then tp
else if tpw <:< bound then recur(tpw)
else
val tpw = tpd.derivedAndType(dropRestricted(tp1), dropRestricted(tp2))
if tpw ne tpd then tpw else tp
case _ =>
tp
keep += lastDropped
recur(tp)

def widenRestricted(tp: Type) =
val tpw = dropRestricted(tp)
if (tpw ne tp) && (tpw <:< bound) then tpw else tp
recur(tp)
end dropSuperTraits

def widenOr(tp: Type) =
val tpw = tp.widenUnion
Expand All @@ -343,10 +359,7 @@ trait ConstraintHandling[AbstractContext] {

val wideInst =
if isSingleton(bound) then inst
else /*widenRestricted*/(widenOr(widenSingle(inst)))
// widenRestricted is currently not called since it's special cased in `dropEnumValue`
// in `Namer`. It's left in here in case we want to generalize the scheme to other
// "protected inheritance" classes.
else dropSuperTraits(widenOr(widenSingle(inst)))
wideInst match
case wideInst: TypeRef if wideInst.symbol.is(Module) =>
TermRef(wideInst.prefix, wideInst.symbol.sourceModule)
Expand Down
4 changes: 3 additions & 1 deletion compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -639,7 +639,6 @@ class Definitions {
@tu lazy val EnumClass: ClassSymbol = ctx.requiredClass("scala.Enum")
@tu lazy val Enum_ordinal: Symbol = EnumClass.requiredMethod(nme.ordinal)

@tu lazy val EnumValueClass: ClassSymbol = ctx.requiredClass("scala.runtime.EnumValue")
@tu lazy val EnumValuesClass: ClassSymbol = ctx.requiredClass("scala.runtime.EnumValues")
@tu lazy val ProductClass: ClassSymbol = ctx.requiredClass("scala.Product")
@tu lazy val Product_canEqual : Symbol = ProductClass.requiredMethod(nme.canEqual_)
Expand Down Expand Up @@ -1308,6 +1307,9 @@ class Definitions {
def isInfix(sym: Symbol)(implicit ctx: Context): Boolean =
(sym eq Object_eq) || (sym eq Object_ne)

@tu lazy val assumedSuperTraits =
Set(ComparableClass, JavaSerializableClass, ProductClass, SerializableClass)

// ----- primitive value class machinery ------------------------------------------

/** This class would also be obviated by the implicit function type design */
Expand Down
4 changes: 4 additions & 0 deletions compiler/src/dotty/tools/dotc/core/SymDenotations.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1166,6 +1166,10 @@ object SymDenotations {
final def isEffectivelySealed(using Context): Boolean =
isOneOf(FinalOrSealed) || isClass && !isOneOf(EffectivelyOpenFlags)

final def isSuperTrait(using Context): Boolean =
isClass
&& (is(SuperTrait) || defn.assumedSuperTraits.contains(symbol.asClass))

/** The class containing this denotation which has the given effective name. */
final def enclosingClassNamed(name: Name)(implicit ctx: Context): Symbol = {
val cls = enclosingClass
Expand Down
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/core/tasty/TreePickler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -705,6 +705,7 @@ class TreePickler(pickler: TastyPickler) {
if (flags.is(Sealed)) writeModTag(SEALED)
if (flags.is(Abstract)) writeModTag(ABSTRACT)
if (flags.is(Trait)) writeModTag(TRAIT)
if flags.is(SuperTrait) then writeModTag(SUPERTRAIT)
if (flags.is(Covariant)) writeModTag(COVARIANT)
if (flags.is(Contravariant)) writeModTag(CONTRAVARIANT)
if (flags.is(Opaque)) writeModTag(OPAQUE)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -639,6 +639,7 @@ class TreeUnpickler(reader: TastyReader,
case STATIC => addFlag(JavaStatic)
case OBJECT => addFlag(Module)
case TRAIT => addFlag(Trait)
case SUPERTRAIT => addFlag(SuperTrait)
case ENUM => addFlag(Enum)
case LOCAL => addFlag(Local)
case SYNTHETIC => addFlag(Synthetic)
Expand Down
4 changes: 3 additions & 1 deletion compiler/src/dotty/tools/dotc/parsing/Parsers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3434,7 +3434,7 @@ object Parsers {
}
}

/** TmplDef ::= ([‘case’] ‘class’ | ‘trait’) ClassDef
/** TmplDef ::= ([‘case’] ‘class’ | [‘super’] ‘trait’) ClassDef
* | [‘case’] ‘object’ ObjectDef
* | ‘enum’ EnumDef
* | ‘given’ GivenDef
Expand All @@ -3444,6 +3444,8 @@ object Parsers {
in.token match {
case TRAIT =>
classDef(start, posMods(start, addFlag(mods, Trait)))
case SUPERTRAIT =>
classDef(start, posMods(start, addFlag(mods, Trait | SuperTrait)))
case CLASS =>
classDef(start, posMods(start, mods))
case CASECLASS =>
Expand Down
7 changes: 6 additions & 1 deletion compiler/src/dotty/tools/dotc/parsing/Scanners.scala
Original file line number Diff line number Diff line change
Expand Up @@ -586,7 +586,8 @@ object Scanners {
currentRegion = r.outer
case _ =>

/** - Join CASE + CLASS => CASECLASS, CASE + OBJECT => CASEOBJECT, SEMI + ELSE => ELSE, COLON + <EOL> => COLONEOL
/** - Join CASE + CLASS => CASECLASS, CASE + OBJECT => CASEOBJECT, SUPER + TRAIT => SUPERTRAIT
* SEMI + ELSE => ELSE, COLON + <EOL> => COLONEOL
* - Insert missing OUTDENTs at EOF
*/
def postProcessToken(): Unit = {
Expand All @@ -602,6 +603,10 @@ object Scanners {
if (token == CLASS) fuse(CASECLASS)
else if (token == OBJECT) fuse(CASEOBJECT)
else reset()
case SUPER =>
lookAhead()
if token == TRAIT then fuse(SUPERTRAIT)
else reset()
case SEMI =>
lookAhead()
if (token != ELSE) reset()
Expand Down
5 changes: 3 additions & 2 deletions compiler/src/dotty/tools/dotc/parsing/Tokens.scala
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,8 @@ object Tokens extends TokensCommon {
final val ERASED = 63; enter(ERASED, "erased")
final val GIVEN = 64; enter(GIVEN, "given")
final val EXPORT = 65; enter(EXPORT, "export")
final val MACRO = 66; enter(MACRO, "macro") // TODO: remove
final val SUPERTRAIT = 66; enter(SUPERTRAIT, "super trait")
final val MACRO = 67; enter(MACRO, "macro") // TODO: remove

/** special symbols */
final val NEWLINE = 78; enter(NEWLINE, "end of statement", "new line")
Expand Down Expand Up @@ -233,7 +234,7 @@ object Tokens extends TokensCommon {
final val canStartTypeTokens: TokenSet = literalTokens | identifierTokens | BitSet(
THIS, SUPER, USCORE, LPAREN, AT)

final val templateIntroTokens: TokenSet = BitSet(CLASS, TRAIT, OBJECT, ENUM, CASECLASS, CASEOBJECT)
final val templateIntroTokens: TokenSet = BitSet(CLASS, TRAIT, OBJECT, ENUM, CASECLASS, CASEOBJECT, SUPERTRAIT)

final val dclIntroTokens: TokenSet = BitSet(DEF, VAL, VAR, TYPE, GIVEN)

Expand Down
9 changes: 7 additions & 2 deletions compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -729,7 +729,7 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
}

private def Modifiers(sym: Symbol): Modifiers = untpd.Modifiers(
sym.flags & (if (sym.isType) ModifierFlags | VarianceFlags else ModifierFlags),
sym.flags & (if (sym.isType) ModifierFlags | VarianceFlags | SuperTrait else ModifierFlags),
if (sym.privateWithin.exists) sym.privateWithin.asType.name else tpnme.EMPTY,
sym.annotations.filterNot(ann => dropAnnotForModText(ann.symbol)).map(_.tree))

Expand Down Expand Up @@ -839,7 +839,11 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
}

protected def templateText(tree: TypeDef, impl: Template): Text = {
val decl = modText(tree.mods, tree.symbol, keywordStr(if (tree.mods.is(Trait)) "trait" else "class"), isType = true)
val kw =
if tree.mods.is(SuperTrait) then "super trait"
else if tree.mods.is(Trait) then "trait"
else "class"
val decl = modText(tree.mods, tree.symbol, keywordStr(kw), isType = true)
( decl ~~ typeText(nameIdText(tree)) ~ withEnclosingDef(tree) { toTextTemplate(impl) }
// ~ (if (tree.hasType && printDebug) i"[decls = ${tree.symbol.info.decls}]" else "") // uncomment to enable
)
Expand Down Expand Up @@ -945,6 +949,7 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
else if (sym.isPackageObject) "package object"
else if (flags.is(Module) && flags.is(Case)) "case object"
else if (sym.isClass && flags.is(Case)) "case class"
else if sym.isClass && flags.is(SuperTrait) then "super trait"
else super.keyString(sym)
}

Expand Down
17 changes: 1 addition & 16 deletions compiler/src/dotty/tools/dotc/typer/Namer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1458,19 +1458,6 @@ class Namer { typer: Typer =>
// println(s"owner = ${sym.owner}, decls = ${sym.owner.info.decls.show}")
def isInlineVal = sym.isOneOf(FinalOrInline, butNot = Method | Mutable)

def isEnumValue(tp: Type) = tp.typeSymbol == defn.EnumValueClass

// Drop EnumValue parents from inferred types of enum constants
def dropEnumValue(tp: Type): Type = tp.dealias match
case tpd @ AndType(tp1, tp2) =>
if isEnumValue(tp1) then tp2
else if isEnumValue(tp2) then tp1
else
val tpw = tpd.derivedAndType(dropEnumValue(tp1), dropEnumValue(tp2))
if tpw ne tpd then tpw else tp
case _ =>
tp

// Widen rhs type and eliminate `|' but keep ConstantTypes if
// definition is inline (i.e. final in Scala2) and keep module singleton types
// instead of widening to the underlying module class types.
Expand All @@ -1479,9 +1466,7 @@ class Namer { typer: Typer =>
def widenRhs(tp: Type): Type =
tp.widenTermRefExpr.simplified match
case ctp: ConstantType if isInlineVal => ctp
case tp =>
val tp1 = ctx.typeComparer.widenInferred(tp, rhsProto)
if sym.is(Enum) then dropEnumValue(tp1) else tp1
case tp => ctx.typeComparer.widenInferred(tp, rhsProto)

// Replace aliases to Unit by Unit itself. If we leave the alias in
// it would be erased to BoxedUnit.
Expand Down
2 changes: 1 addition & 1 deletion docs/docs/internals/syntax.md
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ VarDef ::= PatDef
DefDef ::= DefSig [‘:’ Type] ‘=’ Expr DefDef(_, name, tparams, vparamss, tpe, expr)
| ‘this’ DefParamClause DefParamClauses ‘=’ ConstrExpr DefDef(_, <init>, Nil, vparamss, EmptyTree, expr | Block)
TmplDef ::= ([‘case’] ‘class’ | ‘trait’) ClassDef
TmplDef ::= ([‘case’] ‘class’ | [‘super’] ‘trait’) ClassDef
| [‘case’] ‘object’ ObjectDef
| ‘enum’ EnumDef
| ‘given’ GivenDef
Expand Down
10 changes: 10 additions & 0 deletions library/src-bootstrapped/scala/runtime/EnumValue.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package scala.runtime

super trait EnumValue extends Product, Serializable:
override def canEqual(that: Any) = this eq that.asInstanceOf[AnyRef]
override def productArity: Int = 0
override def productPrefix: String = toString
override def productElement(n: Int): Any =
throw IndexOutOfBoundsException(n.toString)
override def productElementName(n: Int): String =
throw IndexOutOfBoundsException(n.toString)
File renamed without changes.
6 changes: 5 additions & 1 deletion tasty/src/dotty/tools/tasty/TastyFormat.scala
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ Standard-Section: "ASTs" TopLevelStat*
STATIC -- Mapped to static Java member
OBJECT -- An object or its class
TRAIT -- A trait
SUPERTRAIT -- A super trait
ENUM -- A enum class or enum case
LOCAL -- private[this] or protected[this], used in conjunction with PRIVATE or PROTECTED
SYNTHETIC -- Generated by Scala compiler
Expand Down Expand Up @@ -359,6 +360,7 @@ object TastyFormat {
final val OPEN = 40
final val PARAMEND = 41
final val PARAMalias = 42
final val SUPERTRAIT = 43

// Cat. 2: tag Nat

Expand Down Expand Up @@ -473,7 +475,7 @@ object TastyFormat {

/** Useful for debugging */
def isLegalTag(tag: Int): Boolean =
firstSimpleTreeTag <= tag && tag <= PARAMalias ||
firstSimpleTreeTag <= tag && tag <= SUPERTRAIT ||
firstNatTreeTag <= tag && tag <= RENAMED ||
firstASTTreeTag <= tag && tag <= BOUNDED ||
firstNatASTTreeTag <= tag && tag <= NAMEDARG ||
Expand Down Expand Up @@ -502,6 +504,7 @@ object TastyFormat {
| STATIC
| OBJECT
| TRAIT
| SUPERTRAIT
| ENUM
| LOCAL
| SYNTHETIC
Expand Down Expand Up @@ -562,6 +565,7 @@ object TastyFormat {
case STATIC => "STATIC"
case OBJECT => "OBJECT"
case TRAIT => "TRAIT"
case SUPERTRAIT => "SUPERTRAIT"
case ENUM => "ENUM"
case LOCAL => "LOCAL"
case SYNTHETIC => "SYNTHETIC"
Expand Down
32 changes: 32 additions & 0 deletions tests/neg-custom-args/fatal-warnings/supertraits.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
sealed super trait TA
sealed super trait TB
case object a extends TA, TB
case object b extends TA, TB

object Test:

def choose0[X](x: X, y: X): X = x
def choose1[X <: TA](x: X, y: X): X = x
def choose2[X <: TB](x: X, y: X): X = x
def choose3[X <: Product](x: X, y: X): X = x
def choose4[X <: TA & TB](x: X, y: X): X = x

choose0(a, b) match
case _: TA => ???
case _: TB => ???

choose1(a, b) match
case _: TA => ???
case _: TB => ??? // error: unreachable

choose2(a, b) match
case _: TB => ???
case _: TA => ??? // error: unreachable

choose3(a, b) match
case _: Product => ???
case _: TA => ??? // error: unreachable

choose4(a, b) match
case _: (TA & TB) => ???
case _: Product => ??? // error: unreachable
13 changes: 13 additions & 0 deletions tests/neg/supertraits.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
super trait S
trait A
class B extends A, S
class C extends A, S

val x = if ??? then B() else C()
val x1: S = x // error

case object a
case object b
val y = if ??? then a else b
val y1: Product = y // error
val y2: Serializable = y // error
2 changes: 1 addition & 1 deletion tests/run/java-intersection/Test_2.scala
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
object Test {
def main(args: Array[String]): Unit = {
val a = new A_1
val x = new java.io.Serializable {}
val x: java.io.Serializable = new java.io.Serializable {}
a.foo(x)
}
}

0 comments on commit ff9798d

Please sign in to comment.