Skip to content

Commit

Permalink
Synthesise Mirror.Sum for nested hierarchies
Browse files Browse the repository at this point in the history
  • Loading branch information
dwijnand committed Jul 7, 2021
1 parent 8f3fdf5 commit 3629c66
Show file tree
Hide file tree
Showing 5 changed files with 160 additions and 14 deletions.
8 changes: 6 additions & 2 deletions compiler/src/dotty/tools/dotc/transform/SymUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ object SymUtils:
* - none of its children are anonymous classes
* - all of its children are addressable through a path from the parent class
* and also the location of the generated mirror.
* - all of its children are generic products or singletons
* - all of its children are generic products, singletons, or generic sums themselves.
*/
def whyNotGenericSum(declScope: Symbol)(using Context): String =
if (!self.is(Sealed))
Expand All @@ -116,7 +116,11 @@ object SymUtils:
else {
val s = child.whyNotGenericProduct
if (s.isEmpty) s
else i"its child $child is not a generic product because $s"
else if (child.is(Sealed)) {
val s = child.whyNotGenericSum(if child.useCompanionAsMirror then child.linkedClass else ctx.owner)
if (s.isEmpty) s
else i"its child $child is not a generic sum because $s"
} else i"its child $child is not a generic product because $s"
}
}
if (children.isEmpty) "it does not have subclasses"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -525,7 +525,7 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
val pat = Typed(untpd.Ident(nme.WILDCARD).withType(patType), TypeTree(patType))
CaseDef(pat, EmptyTree, Literal(Constant(idx)))
}
Match(param, cases)
Match(param.annotated(New(defn.UncheckedAnnot.typeRef, Nil)), cases)
}

/** - If `impl` is the companion of a generic sum, add `deriving.Mirror.Sum` parent
Expand Down
22 changes: 11 additions & 11 deletions compiler/src/dotty/tools/dotc/typer/Synthesizer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -288,12 +288,12 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
val elemLabels = cls.children.map(c => ConstantType(Constant(c.name.toString)))

def solve(sym: Symbol): Type = sym match
case caseClass: ClassSymbol =>
assert(caseClass.is(Case))
if caseClass.is(Module) then
caseClass.sourceModule.termRef
case childClass: ClassSymbol =>
assert(childClass.isOneOf(Case | Sealed))
if childClass.is(Module) then
childClass.sourceModule.termRef
else
caseClass.primaryConstructor.info match
childClass.primaryConstructor.info match
case info: PolyType =>
// Compute the the full child type by solving the subtype constraint
// `C[X1, ..., Xn] <: P`, where
Expand All @@ -310,13 +310,13 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
case tp => tp
resType <:< target
val tparams = poly.paramRefs
val variances = caseClass.typeParams.map(_.paramVarianceSign)
val variances = childClass.typeParams.map(_.paramVarianceSign)
val instanceTypes = tparams.lazyZip(variances).map((tparam, variance) =>
TypeComparer.instanceType(tparam, fromBelow = variance < 0))
resType.substParams(poly, instanceTypes)
instantiate(using ctx.fresh.setExploreTyperState().setOwner(caseClass))
instantiate(using ctx.fresh.setExploreTyperState().setOwner(childClass))
case _ =>
caseClass.typeRef
childClass.typeRef
case child => child.termRef
end solve

Expand All @@ -331,9 +331,9 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
(mirroredType, elems)

val mirrorType =
mirrorCore(defn.Mirror_SumClass, monoType, mirroredType, cls.name, formal)
.refinedWith(tpnme.MirroredElemTypes, TypeAlias(elemsType))
.refinedWith(tpnme.MirroredElemLabels, TypeAlias(TypeOps.nestedPairs(elemLabels)))
mirrorCore(defn.Mirror_SumClass, monoType, mirroredType, cls.name, formal)
.refinedWith(tpnme.MirroredElemTypes, TypeAlias(elemsType))
.refinedWith(tpnme.MirroredElemLabels, TypeAlias(TypeOps.nestedPairs(elemLabels)))
val mirrorRef =
if useCompanion then companionPath(mirroredType, span)
else anonymousMirror(monoType, ExtendsSumMirror, span)
Expand Down
1 change: 1 addition & 0 deletions compiler/test/dotty/tools/dotc/CompilationTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ class CompilationTests {
compileFile("tests/run-custom-args/no-useless-forwarders.scala", defaultOptions and "-Xmixin-force-forwarders:false"),
compileFile("tests/run-custom-args/defaults-serizaliable-no-forwarders.scala", defaultOptions and "-Xmixin-force-forwarders:false"),
compileFilesInDir("tests/run-custom-args/erased", defaultOptions.and("-language:experimental.erasedDefinitions")),
compileFilesInDir("tests/run-custom-args/fatal-warnings", defaultOptions.and("-Xfatal-warnings")),
compileFilesInDir("tests/run-deep-subtype", allowDeepSubtypes),
compileFilesInDir("tests/run", defaultOptions.and("-Ysafe-init"))
).checkRuns()
Expand Down
141 changes: 141 additions & 0 deletions tests/run-custom-args/fatal-warnings/i11050.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
import scala.compiletime.*
import scala.deriving.*

object OriginalReport:
sealed trait TreeValue
sealed trait SubLevel extends TreeValue
case class Leaf1(value: String) extends TreeValue
case class Leaf2(value: Int) extends SubLevel

// Variants from the initial failure in akka.event.LogEvent
object FromAkkaCB:
sealed trait A
sealed trait B extends A
sealed trait C extends A
case class D() extends B, C
case class E() extends C, B

object FromAkkaCB2:
sealed trait A
sealed trait N extends A
case class B() extends A
case class C() extends A, N

object FromAkkaCB3:
sealed trait A
case class B() extends A
case class C() extends A
class D extends C // ignored pattern: class extending a case class

object NoUnreachableWarnings:
sealed trait Top
object Top

final case class MiddleA() extends Top with Bottom
final case class MiddleB() extends Top with Bottom
final case class MiddleC() extends Top with Bottom

sealed trait Bottom extends Top

object FromAkkaCB4:
sealed trait LogEvent
object LogEvent
case class Error() extends LogEvent
class Error2() extends Error() with LogEventWithMarker // ignored pattern
case class Warning() extends LogEvent
sealed trait LogEventWithMarker extends LogEvent // must be defined late

object FromAkkaCB4simpler:
sealed trait LogEvent
object LogEvent
case class Error() extends LogEvent
class Error2() extends LogEventWithMarker // not a case class
case class Warning() extends LogEvent
sealed trait LogEventWithMarker extends LogEvent

object Test:
def main(args: Array[String]): Unit =
testOriginalReport()
testFromAkkaCB()
testFromAkkaCB2()
end main

def testOriginalReport() =
import OriginalReport._
val m = summon[Mirror.SumOf[TreeValue]]
given Show[TreeValue] = Show.derived[TreeValue]
val leaf1 = Leaf1("1")
val leaf2 = Leaf2(2)

assertEq(List(leaf1, leaf2).map(m.ordinal), List(1, 0))
assertShow[TreeValue](leaf1, "[1] Leaf1(value = \"1\")")
assertShow[TreeValue](leaf2, "[0] [0] Leaf2(value = 2)")
end testOriginalReport

def testFromAkkaCB() =
import FromAkkaCB._
val m = summon[Mirror.SumOf[A]]
given Show[A] = Show.derived[A]
val d = D()
val e = E()

assertEq(List(d, e).map(m.ordinal), List(0, 0))
assertShow[A](d, "[0] [0] D")
assertShow[A](e, "[0] [1] E")
end testFromAkkaCB

def testFromAkkaCB2() =
import FromAkkaCB2._
val m = summon[Mirror.SumOf[A]]
val n = summon[Mirror.SumOf[N]]
given Show[A] = Show.derived[A]
val b = B()
val c = C()

assertEq(List(b, c).map(m.ordinal), List(1, 0))
assertShow[A](b, "[1] B")
assertShow[A](c, "[0] [0] C")
end testFromAkkaCB2

def assertEq[A](obt: A, exp: A) = assert(obt == exp, s"$obt != $exp (obtained != expected)")
def assertShow[A: Show](x: A, s: String) = assertEq(Show.show(x), s)
end Test

trait Show[-T]:
def show(x: T): String

object Show:
given Show[Int] with { def show(x: Int) = s"$x" }
given Show[Char] with { def show(x: Char) = s"'$x'" }
given Show[String] with { def show(x: String) = s"$"$x$"" }

inline def show[T](x: T): String = summonInline[Show[T]].show(x)

transparent inline def derived[T](implicit ev: Mirror.Of[T]): Show[T] = new {
def show(x: T): String = inline ev match {
case m: Mirror.ProductOf[T] => showProduct(x.asInstanceOf[Product], m)
case m: Mirror.SumOf[T] => showCases[m.MirroredElemTypes](0)(x, m.ordinal(x))
}
}

transparent inline def showProduct[T](x: Product, m: Mirror.ProductOf[T]): String =
constValue[m.MirroredLabel] + showElems[m.MirroredElemTypes, m.MirroredElemLabels](0, Nil)(x)

transparent inline def showElems[Elems <: Tuple, Labels <: Tuple](n: Int, elems: List[String])(x: Product): String =
inline (erasedValue[Labels], erasedValue[Elems]) match {
case _: (label *: labels, elem *: elems) =>
val value = show(x.productElement(n).asInstanceOf[elem])
showElems[elems, labels](n + 1, s"${constValue[label]} = $value" :: elems)(x)
case _: (EmptyTuple, EmptyTuple) =>
if elems.isEmpty then "" else elems.mkString(s"(", ", ", ")")
}

transparent inline def showCases[Alts <: Tuple](n: Int)(x: Any, ord: Int): String =
inline erasedValue[Alts] match {
case _: (alt *: alts) =>
if (ord == n) summonFrom {
case m: Mirror.Of[`alt`] => s"[$ord] " + derived[alt](using m).show(x.asInstanceOf[alt])
} else showCases[alts](n + 1)(x, ord)
case _: EmptyTuple => throw new MatchError(x)
}
end Show

0 comments on commit 3629c66

Please sign in to comment.