Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reimplement constraint merging for correctness #13292

Merged
merged 4 commits into from
Aug 18, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 0 additions & 13 deletions compiler/src/dotty/tools/dotc/core/Constraint.scala
Original file line number Diff line number Diff line change
Expand Up @@ -152,16 +152,6 @@ abstract class Constraint extends Showable {
*/
def uninstVars: collection.Seq[TypeVar]

/** The weakest constraint that subsumes both this constraint and `other`.
* The constraints should be _compatible_, meaning that a type lambda
* occurring in both constraints is associated with the same typevars in each.
*
* @param otherHasErrors If true, handle incompatible constraints by
* returning an approximate constraint, instead of
* failing with an exception
*/
def & (other: Constraint, otherHasErrors: Boolean)(using Context): Constraint

/** Whether `tl` is present in both `this` and `that` but is associated with
* different TypeVars there, meaning that the constraints cannot be merged.
*/
Expand All @@ -183,7 +173,4 @@ abstract class Constraint extends Showable {
* of athe type lambda that is associated with the typevar itself.
*/
def checkConsistentVars()(using Context): Unit

/** A string describing the constraint's contents without a header or trailer */
def contentsToString(using Context): String
}
99 changes: 10 additions & 89 deletions compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,13 @@ object OrderingConstraint {
type ParamOrdering = ArrayValuedMap[List[TypeParamRef]]

/** A new constraint with given maps */
private def newConstraint(boundsMap: ParamBounds, lowerMap: ParamOrdering, upperMap: ParamOrdering)(using Context) : OrderingConstraint = {
val result = new OrderingConstraint(boundsMap, lowerMap, upperMap)
ctx.run.recordConstraintSize(result, result.boundsMap.size)
result
}
private def newConstraint(boundsMap: ParamBounds, lowerMap: ParamOrdering, upperMap: ParamOrdering)(using Context) : OrderingConstraint =
if boundsMap.isEmpty && lowerMap.isEmpty && upperMap.isEmpty then
empty
else
val result = new OrderingConstraint(boundsMap, lowerMap, upperMap)
ctx.run.recordConstraintSize(result, result.boundsMap.size)
result

/** A lens for updating a single entry array in one of the three constraint maps */
abstract class ConstraintLens[T <: AnyRef: ClassTag] {
Expand Down Expand Up @@ -457,48 +459,6 @@ class OrderingConstraint(private val boundsMap: ParamBounds,

// ----------- Joins -----------------------------------------------------

def & (other: Constraint, otherHasErrors: Boolean)(using Context): OrderingConstraint = {

def merge[T](m1: ArrayValuedMap[T], m2: ArrayValuedMap[T], join: (T, T) => T): ArrayValuedMap[T] = {
var merged = m1
def mergeArrays(xs1: Array[T], xs2: Array[T]) = {
val xs = xs1.clone
for (i <- xs.indices) xs(i) = join(xs1(i), xs2(i))
xs
}
m2.foreachBinding { (poly, xs2) =>
merged = merged.updated(poly,
if (m1.contains(poly)) mergeArrays(m1(poly), xs2) else xs2)
}
merged
}

def mergeParams(ps1: List[TypeParamRef], ps2: List[TypeParamRef]) =
ps2.foldLeft(ps1)((ps1, p2) => if (ps1.contains(p2)) ps1 else p2 :: ps1)

// Must be symmetric
def mergeEntries(e1: Type, e2: Type): Type =
(e1, e2) match {
case _ if e1 eq e2 => e1
case (e1: TypeBounds, e2: TypeBounds) => e1 & e2
case (e1: TypeBounds, _) if e1 contains e2 => e2
case (_, e2: TypeBounds) if e2 contains e1 => e1
case (tv1: TypeVar, tv2: TypeVar) if tv1 eq tv2 => e1
case _ =>
if (otherHasErrors)
e1
else
throw new AssertionError(i"cannot merge $this with $other, mergeEntries($e1, $e2) failed")
}

val that = other.asInstanceOf[OrderingConstraint]

new OrderingConstraint(
merge(this.boundsMap, that.boundsMap, mergeEntries),
merge(this.lowerMap, that.lowerMap, mergeParams),
merge(this.upperMap, that.upperMap, mergeParams))
}.showing(i"constraint merge $this with $other = $result", constr)

def hasConflictingTypeVarsFor(tl: TypeLambda, that: Constraint): Boolean =
contains(tl) && that.contains(tl) &&
// Since TypeVars are allocated in bulk for each type lambda, we only have
Expand Down Expand Up @@ -641,49 +601,10 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
upperMap.foreachBinding((_, paramss) => paramss.foreach(_.foreach(checkClosedType(_, "upper"))))
end checkClosed

// ---------- toText -----------------------------------------------------

private def contentsToText(printer: Printer): Text =
//Printer.debugPrintUnique = true
def entryText(tp: Type) = tp match {
case tp: TypeBounds =>
tp.toText(printer)
case _ =>
" := " ~ tp.toText(printer)
}
val indent = 3
val uninstVarsText = " uninstantiated variables: " ~
Text(uninstVars.map(_.toText(printer)), ", ")
val constrainedText =
" constrained types: " ~ Text(domainLambdas map (_.toText(printer)), ", ")
val boundsText =
" bounds: " ~ {
val assocs =
for (param <- domainParams)
yield (" " * indent) ~ param.toText(printer) ~ entryText(entry(param))
Text(assocs, "\n")
}
val orderingText =
" ordering: " ~ {
val deps =
for {
param <- domainParams
ups = minUpper(param)
if ups.nonEmpty
}
yield
(" " * indent) ~ param.toText(printer) ~ " <: " ~
Text(ups.map(_.toText(printer)), ", ")
Text(deps, "\n")
}
//Printer.debugPrintUnique = false
Text.lines(List(uninstVarsText, constrainedText, boundsText, orderingText))
// ---------- Printing -----------------------------------------------------

override def toText(printer: Printer): Text =
Text.lines(List("Constraint(", contentsToText(printer), ")"))

def contentsToString(using Context): String =
contentsToText(ctx.printer).show
printer.toText(this)

override def toString: String = {
def entryText(tp: Type): String = tp match {
Expand All @@ -692,7 +613,7 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
}
val constrainedText =
" constrained types = " + domainLambdas.mkString("\n")
val boundsText = domainLambdas
val boundsText =
" bounds = " + {
val assocs =
for (param <- domainParams)
Expand Down
38 changes: 35 additions & 3 deletions compiler/src/dotty/tools/dotc/core/TyperState.scala
Original file line number Diff line number Diff line change
Expand Up @@ -187,11 +187,43 @@ class TyperState() {
*/
def mergeConstraintWith(that: TyperState)(using Context): Unit =
that.ensureNotConflicting(constraint)
constraint = constraint & (that.constraint, otherHasErrors = that.reporter.errorsReported)
for tvar <- constraint.uninstVars do
if !isOwnedAnywhere(this, tvar) then includeVar(tvar)

val comparingCtx =
if ctx.typerState == this then ctx
else ctx.fresh.setTyperState(this)

comparing(typeComparer =>
val other = that.constraint
val res = other.domainLambdas.forall(tl =>
// Integrate the type lambdas from `other`
constraint.contains(tl) || other.isRemovable(tl) || {
val tvars = tl.paramRefs.map(other.typeVarOfParam(_)).collect { case tv: TypeVar => tv }
tvars.foreach(tvar => if !tvar.inst.exists && !isOwnedAnywhere(this, tvar) then includeVar(tvar))
typeComparer.addToConstraint(tl, tvars)
}) &&
// Integrate the additional constraints on type variables from `other`
constraint.uninstVars.forall(tv =>
val p = tv.origin
val otherLos = other.lower(p)
val otherHis = other.upper(p)
val otherEntry = other.entry(p)
( (otherLos eq constraint.lower(p)) || otherLos.forall(_ <:< p)) &&
( (otherHis eq constraint.upper(p)) || otherHis.forall(p <:< _)) &&
((otherEntry eq constraint.entry(p)) || otherEntry.match
case NoType =>
true
case tp: TypeBounds =>
tp.contains(tv)
case tp =>
tv =:= tp
)
)
assert(res || ctx.reporter.errorsReported, i"cannot merge $constraint with $other.")
)(using comparingCtx)

for tl <- constraint.domainLambdas do
if constraint.isRemovable(tl) then constraint = constraint.remove(tl)
end mergeConstraintWith

/** Take ownership of `tvar`.
*
Expand Down
41 changes: 41 additions & 0 deletions compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -599,6 +599,47 @@ class PlainPrinter(_ctx: Context) extends Printer {
case _ => "{...}"
s"import $exprStr.$selectorStr"

def toText(c: OrderingConstraint): Text =
val savedConstraint = ctx.typerState.constraint
try
// The current TyperState constraint determines how type variables are printed
ctx.typerState.constraint = c
def entryText(tp: Type) = tp match {
case tp: TypeBounds =>
toText(tp)
case _ =>
" := " ~ toText(tp)
}
val indent = 3
val uninstVarsText = " uninstantiated variables: " ~
Text(c.uninstVars.map(toText), ", ")
val constrainedText =
" constrained types: " ~ Text(c.domainLambdas.map(toText), ", ")
val boundsText =
" bounds: " ~ {
val assocs =
for (param <- c.domainParams)
yield (" " * indent) ~ toText(param) ~ entryText(c.entry(param))
Text(assocs, "\n")
}
val orderingText =
" ordering: " ~ {
val deps =
for {
param <- c.domainParams
ups = c.minUpper(param)
if ups.nonEmpty
}
yield
(" " * indent) ~ toText(param) ~ " <: " ~
Text(ups.map(toText), ", ")
Text(deps, "\n")
}
//Printer.debugPrintUnique = false
Text.lines(List(uninstVarsText, constrainedText, boundsText, orderingText))
finally
ctx.typerState.constraint = savedConstraint

def plain: PlainPrinter = this

protected def keywordStr(text: String): String = coloredStr(text, SyntaxHighlighting.KeywordColor)
Expand Down
3 changes: 3 additions & 0 deletions compiler/src/dotty/tools/dotc/printing/Printer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,9 @@ abstract class Printer {
/** Textual representation of info relating to an import clause */
def toText(result: ImportInfo): Text

/** Textual representation of a constraint */
def toText(c: OrderingConstraint): Text

/** Render element within highest precedence */
def toTextLocal(elem: Showable): Text =
atPrec(DotPrec) { elem.toText(this) }
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/typer/ErrorReporting.scala
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ object ErrorReporting {
"the empty constraint"
else
i"""a constraint with:
|${c.contentsToString}"""
|$c"""
i"""
|${TypeComparer.explained(_.isSubType(found, expected), header)}
|
Expand Down
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,7 @@ object ProtoTypes {
def typedArgs(norm: (untpd.Tree, Int) => untpd.Tree = sameTree)(using Context): List[Tree] =
if state.typedArgs.size == args.length then state.typedArgs
else
val passedCtx = ctx
val passedTyperState = ctx.typerState
inContext(protoCtx.withUncommittedTyperState) {
val protoTyperState = ctx.typerState
Expand Down Expand Up @@ -409,8 +410,7 @@ object ProtoTypes {
tvar.instantiate(fromBelow = false)
case _ =>
}

passedTyperState.mergeConstraintWith(protoTyperState)
passedTyperState.mergeConstraintWith(protoTyperState)(using passedCtx)
end if
args1
}
Expand Down
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/util/SimpleIdentityMap.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import collection.mutable.ListBuffer
* It has linear complexity for `apply`, `updated`, and `remove`.
*/
abstract class SimpleIdentityMap[K <: AnyRef, +V >: Null <: AnyRef] extends (K => V) {
final def isEmpty: Boolean = this eq SimpleIdentityMap.myEmpty
def size: Int
def apply(k: K): V
def remove(k: K): SimpleIdentityMap[K, V]
Expand Down
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/util/Stats.scala
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ import collection.mutable
aggregate()
println()
println(hits.toList.sortBy(_._2).map{ case (x, y) => s"$x -> $y" } mkString "\n")
hits.clear()
}
}
else op
Expand Down
54 changes: 54 additions & 0 deletions compiler/test/dotty/tools/dotc/core/ConstraintsTest.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package dotty.tools
package dotc.core

import vulpix.TestConfiguration

import dotty.tools.dotc.core.Contexts.{*, given}
import dotty.tools.dotc.core.Decorators.{*, given}
import dotty.tools.dotc.core.Symbols.*
import dotty.tools.dotc.core.Types.*
import dotty.tools.dotc.typer.ProtoTypes.constrained

import org.junit.Test

import dotty.tools.DottyTest

class ConstraintsTest:

@Test def mergeParamsTransitivity: Unit =
inCompilerContext(TestConfiguration.basicClasspath,
scalaSources = "trait A { def foo[S, T, R]: Any }") {
val tp = constrained(requiredClass("A").typeRef.select("foo".toTermName).info.asInstanceOf[TypeLambda])
val List(s, t, r) = tp.paramRefs

val innerCtx = ctx.fresh.setExploreTyperState()
inContext(innerCtx) {
s <:< t
}

t <:< r

ctx.typerState.mergeConstraintWith(innerCtx.typerState)
assert(s frozen_<:< r,
i"Merging constraints `?S <: ?T` and `?T <: ?R` should result in `?S <:< ?R`: ${ctx.typerState.constraint}")
}
end mergeParamsTransitivity

@Test def mergeBoundsTransitivity: Unit =
inCompilerContext(TestConfiguration.basicClasspath,
scalaSources = "trait A { def foo[S, T]: Any }") {
val tp = constrained(requiredClass("A").typeRef.select("foo".toTermName).info.asInstanceOf[TypeLambda])
val List(s, t) = tp.paramRefs

val innerCtx = ctx.fresh.setExploreTyperState()
inContext(innerCtx) {
s <:< t
}

defn.IntType <:< s

ctx.typerState.mergeConstraintWith(innerCtx.typerState)
assert(defn.IntType frozen_<:< t,
i"Merging constraints `?S <: ?T` and `Int <: ?S` should result in `Int <:< ?T`: ${ctx.typerState.constraint}")
}
end mergeBoundsTransitivity
37 changes: 37 additions & 0 deletions tests/pos/i12730.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
class ComponentSimple

class Props {
def apply(props: Any): Any = ???
}

class Foo[C] {
def build: ComponentSimple = ???
}

class Bar[E] {
def render(r: E => Any): Unit = {}
}

trait Conv[A, B] {
def apply(a: A): B
}

object Test {
def toComponentCtor[F](c: ComponentSimple): Props = ???

def defaultToNoBackend[G, H](ev: G => Foo[H]): Conv[Foo[H], Bar[H]] = ???

def conforms[A]: A => A = ???

def problem = Main // crashes

def foo[H]: Foo[H] = ???

val NameChanger =
foo
.build

val Main =
defaultToNoBackend(conforms).apply(foo)
.render(_ => toComponentCtor(NameChanger)(13))
}