Skip to content

Commit

Permalink
Instantiate Type Vars in completion labels of extension methods (#18914)
Browse files Browse the repository at this point in the history
This PR fixes the completion labels for extension methods and old style
completion methods.

It required an API change of `Completion.scala` as it didn't contain
enough information to properly create labels on presentation compiler
side. Along with this following parts were rewritten to avoid
recomputation:
- computation of adjustedPath which is necessary for extension construct
completions,
- computation of prefix is now unified across the metals and
presentation compilers,
- completionKind was changed, and kind Scope, Member are now part of
completion mode.


The biggest change is basically added support for old style extension
methods.
I found out that the type var was not instantiated after
`ImplicitSearch` computation, and was only calculated at later stages. I
don't have enough knowledge about Inference and Typer to know whether
this is a bug or rather an intended behaviour but I managed to solve it
without touching the typer:
```scala

      def tryToInstantiateTypeVars(conversionTarget: SearchSuccess): Type =
        try
          val typingCtx = ctx.fresh
          inContext(typingCtx):
            val methodRefTree = ref(conversionTarget.ref, needLoad = false)
            val convertedTree = ctx.typer.typedAheadExpr(untpd.Apply(untpd.TypedSplice(methodRefTree), untpd.TypedSplice(qual) :: Nil))
            Inferencing.fullyDefinedType(convertedTree.tpe, "", pos)
        catch
          case error => conversionTarget.tree.tpe // fallback to not fully defined type
```
[Cherry-picked 95266f2]
  • Loading branch information
rochala authored and WojciechMazur committed Jun 28, 2024
1 parent 26006ce commit 43003c7
Show file tree
Hide file tree
Showing 20 changed files with 525 additions and 375 deletions.
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/core/Flags.scala
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ object Flags {
val (AccessorOrSealed @ _, Accessor @ _, Sealed @ _) = newFlags(11, "<accessor>", "sealed")

/** A mutable var, an open class */
val (MutableOrOpen @ __, Mutable @ _, Open @ _) = newFlags(12, "mutable", "open")
val (MutableOrOpen @ _, Mutable @ _, Open @ _) = newFlags(12, "mutable", "open")

/** Symbol is local to current class (i.e. private[this] or protected[this]
* pre: Private or Protected are also set
Expand Down
272 changes: 159 additions & 113 deletions compiler/src/dotty/tools/dotc/interactive/Completion.scala

Large diffs are not rendered by default.

29 changes: 29 additions & 0 deletions compiler/src/dotty/tools/dotc/interactive/Interactive.scala
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,21 @@ object Interactive {
false
}


/** Some information about the trees is lost after Typer such as Extension method construct
* is expanded into methods. In order to support completions in those cases
* we have to rely on untyped trees and only when types are necessary use typed trees.
*/
def resolveTypedOrUntypedPath(tpdPath: List[Tree], pos: SourcePosition)(using Context): List[untpd.Tree] =
lazy val untpdPath: List[untpd.Tree] = NavigateAST
.pathTo(pos.span, List(ctx.compilationUnit.untpdTree), true).collect:
case untpdTree: untpd.Tree => untpdTree

tpdPath match
case (_: Bind) :: _ => tpdPath
case (_: untpd.TypTree) :: _ => tpdPath
case _ => untpdPath

/**
* Is this tree using a renaming introduced by an import statement or an alias for `this`?
*
Expand All @@ -436,6 +451,20 @@ object Interactive {
def sameName(n0: Name, n1: Name): Boolean =
n0.stripModuleClassSuffix.toTermName eq n1.stripModuleClassSuffix.toTermName

/** https://scala-lang.org/files/archive/spec/3.4/02-identifiers-names-and-scopes.html
* import java.lang.*
* {
* import scala.*
* {
* import Predef.*
* { /* source */ }
* }
* }
*/
def isImportedByDefault(sym: Symbol)(using Context): Boolean =
val owner = sym.effectiveOwner
owner == defn.ScalaPredefModuleClass || owner == defn.ScalaPackageClass || owner == defn.JavaLangPackageClass

private[interactive] def safely[T](op: => List[T]): List[T] =
try op catch { case ex: TypeError => Nil }
}
Expand Down
100 changes: 94 additions & 6 deletions language-server/test/dotty/tools/languageserver/CompletionTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1019,33 +1019,33 @@ class CompletionTest {
| val x = Bar.${m1}"""
.completion(
("getClass", Method, "[X0 >: Foo.Bar.type](): Class[? <: X0]"),
("ensuring", Method, "(cond: Boolean): A"),
("ensuring", Method, "(cond: Boolean): Foo.Bar.type"),
("##", Method, "=> Int"),
("nn", Method, "=> Foo.Bar.type"),
("==", Method, "(x$0: Any): Boolean"),
("ensuring", Method, "(cond: Boolean, msg: => Any): A"),
("ensuring", Method, "(cond: Boolean, msg: => Any): Foo.Bar.type"),
("ne", Method, "(x$0: Object): Boolean"),
("valueOf", Method, "($name: String): Foo.Bar"),
("equals", Method, "(x$0: Any): Boolean"),
("wait", Method, "(x$0: Long): Unit"),
("hashCode", Method, "(): Int"),
("notifyAll", Method, "(): Unit"),
("values", Method, "=> Array[Foo.Bar]"),
("", Method, "[B](y: B): (A, B)"),
("", Method, "[B](y: B): (Foo.Bar.type, B)"),
("!=", Method, "(x$0: Any): Boolean"),
("fromOrdinal", Method, "(ordinal: Int): Foo.Bar"),
("asInstanceOf", Method, "[X0]: X0"),
("->", Method, "[B](y: B): (A, B)"),
("->", Method, "[B](y: B): (Foo.Bar.type, B)"),
("wait", Method, "(x$0: Long, x$1: Int): Unit"),
("`back-tick`", Field, "Foo.Bar"),
("notify", Method, "(): Unit"),
("formatted", Method, "(fmtstr: String): String"),
("ensuring", Method, "(cond: A => Boolean, msg: => Any): A"),
("ensuring", Method, "(cond: Foo.Bar.type => Boolean, msg: => Any): Foo.Bar.type"),
("wait", Method, "(): Unit"),
("isInstanceOf", Method, "[X0]: Boolean"),
("`match`", Field, "Foo.Bar"),
("toString", Method, "(): String"),
("ensuring", Method, "(cond: A => Boolean): A"),
("ensuring", Method, "(cond: Foo.Bar.type => Boolean): Foo.Bar.type"),
("eq", Method, "(x$0: Object): Boolean"),
("synchronized", Method, "[X0](x$0: X0): X0")
)
Expand Down Expand Up @@ -1576,6 +1576,61 @@ class CompletionTest {
|"""
.completion(m1, Set(("TTT", Field, "T.TTT")))

@Test def properTypeVariable: Unit =
code"""|object M:
| List(1,2,3).filterNo$m1
|"""
.completion(m1, Set(("filterNot", Method, "(p: Int => Boolean): List[Int]")))

@Test def properTypeVariableForExtensionMethods: Unit =
code"""|object M:
| extension [T](x: List[T]) def test(aaa: T): T = ???
| List(1,2,3).tes$m1
|
|"""
.completion(m1, Set(("test", Method, "(aaa: Int): Int")))

@Test def properTypeVariableForExtensionMethodsByName: Unit =
code"""|object M:
| extension [T](xs: List[T]) def test(p: T => Boolean): List[T] = ???
| List(1,2,3).tes$m1
|"""
.completion(m1, Set(("test", Method, "(p: Int => Boolean): List[Int]")))

@Test def genericExtensionTypeParameterInference: Unit =
code"""|object M:
| extension [T](xs: T) def test(p: T): T = ???
| 3.tes$m1
|"""
.completion(m1, Set(("test", Method, "(p: Int): Int")))

@Test def genericExtensionTypeParameterInferenceByName: Unit =
code"""|object M:
| extension [T](xs: T) def test(p: T => Boolean): T = ???
| 3.tes$m1
|"""
.completion(m1, Set(("test", Method, "(p: Int => Boolean): Int")))

@Test def properTypeVariableForImplicitDefs: Unit =
code"""|object M:
| implicit class ListUtils[T](xs: List[T]) {
| def test(p: T => Boolean): List[T] = ???
| }
| List(1,2,3).tes$m1
|"""
.completion(m1, Set(("test", Method, "(p: Int => Boolean): List[Int]")))

@Test def properTypeParameterForImplicitDefs: Unit =
code"""|object M:
| implicit class ListUtils[T](xs: T) {
| def test(p: T => Boolean): T = ???
| }
| new ListUtils(1).tes$m1
| 1.tes$m2
|"""
.completion(m1, Set(("test", Method, "(p: Int => Boolean): Int")))
.completion(m2, Set(("test", Method, "(p: Int => Boolean): Int")))

@Test def selectDynamic: Unit =
code"""|import scala.language.dynamics
|class Foo extends Dynamic {
Expand All @@ -1591,4 +1646,37 @@ class CompletionTest {
|"""
.completion(m1, Set(("selectDynamic", Method, "(field: String): Foo")))
.completion(m2, Set(("banana", Method, "=> Int")))

@Test def shadowedImport: Unit =
code"""|
|import Matches.*
|object Matches {
| val Number = "".r
|}
|object Main {
| Num$m1
|}
|""".completion(m1, Set(
("Number", Field, "scala.util.matching.Regex"),
("NumberFormatException", Module, "NumberFormatException"),
("Numeric", Field, "scala.math.Numeric")
))

@Test def shadowedImportType: Unit =
code"""|
|import Matches.*
|object Matches {
| val Number = "".r
|}
|object Main {
| val x: Num$m1
|}
|""".completion(m1, Set(
("Number", Class, "Number"),
("Number", Field, "scala.util.matching.Regex"),
("NumberFormatException", Module, "NumberFormatException"),
("NumberFormatException", Field, "NumberFormatException"),
("Numeric", Field, "Numeric"),
("Numeric", Field, "scala.math.Numeric")
))
}
36 changes: 27 additions & 9 deletions presentation-compiler/src/main/dotty/tools/pc/IndexedContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import dotty.tools.dotc.core.Names.*
import dotty.tools.dotc.core.Scopes.EmptyScope
import dotty.tools.dotc.core.Symbols.*
import dotty.tools.dotc.core.Types.*
import dotty.tools.dotc.interactive.Interactive
import dotty.tools.dotc.typer.ImportInfo
import dotty.tools.pc.IndexedContext.Result
import dotty.tools.pc.utils.MtagsEnrichments.*
Expand All @@ -31,16 +32,33 @@ sealed trait IndexedContext:
case Some(symbols) if symbols.exists(_ == sym) =>
Result.InScope
case Some(symbols)
if symbols
.exists(s => isTypeAliasOf(s, sym) || isTermAliasOf(s, sym)) =>
Result.InScope
if symbols.exists(s => isNotConflictingWithDefault(s, sym) || isTypeAliasOf(s, sym) || isTermAliasOf(s, sym)) =>
Result.InScope
// when all the conflicting symbols came from an old version of the file
case Some(symbols) if symbols.nonEmpty && symbols.forall(_.isStale) =>
Result.Missing
case Some(symbols) if symbols.nonEmpty && symbols.forall(_.isStale) => Result.Missing
case Some(_) => Result.Conflict
case None => Result.Missing
end lookupSym

/**
* Scala by default imports following packages:
* https://scala-lang.org/files/archive/spec/3.4/02-identifiers-names-and-scopes.html
* import java.lang.*
* {
* import scala.*
* {
* import Predef.*
* { /* source */ }
* }
* }
*
* This check is necessary for proper scope resolution, because when we compare symbols from
* index including the underlying type like scala.collection.immutable.List it actually
* is in current scope in form of type forwarder imported from Predef.
*/
private def isNotConflictingWithDefault(sym: Symbol, queriedSym: Symbol): Boolean =
sym.info.widenDealias =:= queriedSym.info.widenDealias && (Interactive.isImportedByDefault(sym))

final def hasRename(sym: Symbol, as: String): Boolean =
rename(sym) match
case Some(v) => v == as
Expand All @@ -49,15 +67,15 @@ sealed trait IndexedContext:
// detects import scope aliases like
// object Predef:
// val Nil = scala.collection.immutable.Nil
private def isTermAliasOf(termAlias: Symbol, sym: Symbol): Boolean =
private def isTermAliasOf(termAlias: Symbol, queriedSym: Symbol): Boolean =
termAlias.isTerm && (
sym.info match
queriedSym.info match
case clz: ClassInfo => clz.appliedRef =:= termAlias.info.resultType
case _ => false
)

private def isTypeAliasOf(alias: Symbol, sym: Symbol): Boolean =
alias.isAliasType && alias.info.metalsDealias.typeSymbol == sym
private def isTypeAliasOf(alias: Symbol, queriedSym: Symbol): Boolean =
alias.isAliasType && alias.info.metalsDealias.typeSymbol == queriedSym

final def isEmpty: Boolean = this match
case IndexedContext.Empty => true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,20 @@ import java.net.URI

import scala.meta.pc.OffsetParams

import dotty.tools.dotc.ast.tpd.*
import dotty.tools.dotc.ast.untpd.*
import dotty.tools.dotc.ast.untpd.ImportSelector
import dotty.tools.dotc.core.Contexts.*
import dotty.tools.dotc.core.StdNames.*
import dotty.tools.dotc.util.Chars
import dotty.tools.dotc.util.SourcePosition
import dotty.tools.dotc.util.Spans
import dotty.tools.dotc.interactive.Completion
import dotty.tools.pc.utils.MtagsEnrichments.*

import org.eclipse.lsp4j as l
import scala.annotation.tailrec

enum CompletionKind:
case Empty, Scope, Members

case class CompletionPos(
kind: CompletionKind,
start: Int,
end: Int,
query: String,
Expand All @@ -40,29 +37,21 @@ object CompletionPos:
def infer(
cursorPos: SourcePosition,
offsetParams: OffsetParams,
treePath: List[Tree]
adjustedPath: List[Tree]
)(using Context): CompletionPos =
infer(cursorPos, offsetParams.uri().nn, offsetParams.text().nn, treePath)
infer(cursorPos, offsetParams.uri().nn, offsetParams.text().nn, adjustedPath)

def infer(
cursorPos: SourcePosition,
uri: URI,
text: String,
treePath: List[Tree]
adjustedPath: List[Tree]
)(using Context): CompletionPos =
val start = inferIdentStart(cursorPos, text, treePath)
val end = inferIdentEnd(cursorPos, text)
val query = text.substring(start, end)
val prevIsDot =
if start - 1 >= 0 then text.charAt(start - 1) == '.' else false
val kind =
if prevIsDot then CompletionKind.Members
else if isImportOrExportSelect(cursorPos, treePath) then
CompletionKind.Members
else if query.nn.isEmpty then CompletionKind.Empty
else CompletionKind.Scope
val identEnd = inferIdentEnd(cursorPos, text)
val query = Completion.completionPrefix(adjustedPath, cursorPos)
val start = cursorPos.point - query.length()

CompletionPos(kind, start, end, query.nn, cursorPos, uri)
CompletionPos(start, identEnd, query.nn, cursorPos, uri)
end infer

/**
Expand All @@ -87,57 +76,6 @@ object CompletionPos:
(i, tabIndented)
end inferIndent

private def isImportOrExportSelect(
pos: SourcePosition,
path: List[Tree],
)(using Context): Boolean =
@tailrec
def loop(enclosing: List[Tree]): Boolean =
enclosing match
case head :: tl if !head.sourcePos.contains(pos) => loop(tl)
case (tree: (Import | Export)) :: _ =>
tree.selectors.exists(_.imported.sourcePos.contains(pos))
case _ => false

loop(path)


/**
* Returns the start offset of the identifier starting as the given offset position.
*/
private def inferIdentStart(
pos: SourcePosition,
text: String,
path: List[Tree]
)(using Context): Int =
def fallback: Int =
var i = pos.point - 1
while i >= 0 && Chars.isIdentifierPart(text.charAt(i)) do i -= 1
i + 1
def loop(enclosing: List[Tree]): Int =
enclosing match
case Nil => fallback
case head :: tl =>
if !head.sourcePos.contains(pos) then loop(tl)
else
head match
case i: Ident => i.sourcePos.point
case s: Select =>
if s.name.toTermName == nme.ERROR || s.span.exists && pos.span.point < s.span.point
then fallback
else s.span.point
case Import(_, sel) =>
sel
.collectFirst {
case ImportSelector(imported, renamed, _)
if imported.sourcePos.contains(pos) =>
imported.sourcePos.point
}
.getOrElse(fallback)
case _ => fallback
loop(path)
end inferIdentStart

/**
* Returns the end offset of the identifier starting as the given offset position.
*/
Expand Down
Loading

0 comments on commit 43003c7

Please sign in to comment.