Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Scala 3 support #56

Merged
merged 3 commits into from
Dec 11, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ jobs:
- 2.13.6
- 2.13.7
- 2.13.8
- 3.3.3
- 3.4.1
java:
- 1.8
- 1.11
Expand Down
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
**/target
.*
!.gitignore
49 changes: 42 additions & 7 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,32 @@ crossScalaVersions := List(
"2.13.5",
"2.13.6",
"2.13.7",
"2.13.8"
"2.13.8",
"3.3.3",
"3.4.1"
)

crossVersion := CrossVersion.full
scalacOptions ++= Seq("-deprecation")

libraryDependencies ++= Seq(
scalaOrganization.value % "scala-compiler" % scalaVersion.value,
"org.scalatest" %% "scalatest" % "3.0.8" % Test,
scalaVersion.value match {
case version if version.startsWith("3.") =>
"org.scala-lang" %% "scala3-compiler" % scalaVersion.value
case _ =>
scalaOrganization.value % "scala-compiler" % scalaVersion.value
},
"org.scalatest" %% "scalatest" % "3.2.18" % Test,
"commons-io" % "commons-io" % "2.6" % Test
)

organization := "com.github.cb372"
homepage := Some(url("https://github.com/cb372/scala-typed-holes"))
licenses := Seq("Apache License, Version 2.0" -> url("http://www.apache.org/licenses/LICENSE-2.0.html"))
licenses := Seq(
"Apache License, Version 2.0" -> url(
"http://www.apache.org/licenses/LICENSE-2.0.html"
)
)
developers := List(
Developer(
"cb372",
Expand All @@ -39,17 +51,40 @@ developers := List(
Test / fork := true
Test / javaOptions ++= {
val jar = (Compile / packageBin).value
val scalacClasspath = scalaInstance.value.allJars.mkString(java.io.File.pathSeparator)
val scalacClasspath =
scalaInstance.value.allJars.mkString(java.io.File.pathSeparator)
Seq(
s"-Dplugin.jar=${jar.getAbsolutePath}",
s"-Dscalac.classpath=$scalacClasspath"
)
}

val `scala-typed-holes` = project.in(file("."))
val `scala-typed-holes` =
project
.in(file("."))
.enablePlugins(BuildInfoPlugin)

val `scala-typed-holes-3` = // just for IDE support
project
.in(file("scala-typed-holes-3"))
.settings(
scalaVersion := "3.4.1",
crossScalaVersions := Nil,
Compile / unmanagedSourceDirectories := Seq(),
Compile / unmanagedSourceDirectories += (ThisBuild / baseDirectory).value / "src" / "main" / "scala",
Compile / unmanagedSourceDirectories += (ThisBuild / baseDirectory).value / "src" / "main" / "scala-3",
moduleName := "scala-typed-holes-3",
target := (ThisBuild / baseDirectory).value / "target" / "target3",
publish / skip := true,
libraryDependencies ++= Seq(
"org.scala-lang" %% "scala3-compiler" % "3.4.1"
)
)

val docs = project
.in(file("generated-docs")) // important: it must not be the actual directory name, i.e. docs/
.in(
file("generated-docs")
) // important: it must not be the actual directory name, i.e. docs/
.settings(
scalaVersion := "2.12.10",
crossScalaVersions := Nil,
Expand Down
2 changes: 1 addition & 1 deletion project/build.properties
Original file line number Diff line number Diff line change
@@ -1 +1 @@
sbt.version=1.5.2
sbt.version=1.9.9
1 change: 1 addition & 0 deletions project/plugins.sbt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
addSbtPlugin("org.scalameta" % "sbt-mdoc" % "1.3.1")
addSbtPlugin("com.github.sbt" % "sbt-ci-release" % "1.5.10")
addSbtPlugin("com.eed3si9n" % "sbt-buildinfo" % "0.11.0")
1 change: 1 addition & 0 deletions src/main/resources/plugin.properties
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pluginClass=holes.TypedHolesPlugin
47 changes: 47 additions & 0 deletions src/main/scala-3/holes/NamedHolesPhase.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package holes

import dotty.tools.dotc.plugins.PluginPhase
import dotty.tools.dotc.parsing.Parser
import dotty.tools.dotc.typer.TyperPhase
import dotty.tools.dotc.core.Contexts.Context
import dotty.tools.dotc.ast.tpd
import dotty.tools.dotc.ast.untpd
import dotty.tools.dotc.core.Names.Name
import dotty.tools.dotc.util.Property
import dotty.tools.dotc.core.Names.TermName
import dotty.tools.dotc.core.StdNames
import dotty.tools.dotc.report
import dotty.tools.dotc.core.Contexts.ctx
import dotty.tools.dotc.core.Contexts.atPhase
import dotty.tools.dotc.transform.MegaPhase
import dotty.tools.dotc.ast.untpd.TreeTraverser
import dotty.tools.dotc.ast.Trees.CaseDef
import dotty.tools.dotc.ast.untpd.UntypedTreeMap

class NamedHolesPhase extends PluginPhase:
override def phaseName: String = "named-holes"

override val runsAfter: Set[String] = Set(Parser.name)
override val runsBefore: Set[String] = Set(TyperPhase.name)

override def prepareForUnit(tree: tpd.Tree)(using Context): Context =
val namedHolesTreeMap = new UntypedTreeMap {
override def transform(tree: tpd.Tree)(using Context): tpd.Tree =
tree match
case tpd.Ident(NamedHole(name)) =>
val copied = cpy.Ident(tree)(StdNames.nme.???)
copied.putAttachment(NamedHole.NamedHole, HoleName(name))
copied
case _ =>
super.transform(tree)
}
ctx.compilationUnit.untpdTree =
namedHolesTreeMap.transform(ctx.compilationUnit.untpdTree)
ctx

object NamedHole:
val NamedHole: Property.Key[HoleName] = Property.StickyKey()
val pattern = "^__([a-zA-Z0-9_]+)$".r

def unapply(name: Name): Option[String] =
pattern.unapplySeq(name.decode.toString).flatMap(_.headOption)
127 changes: 127 additions & 0 deletions src/main/scala-3/holes/TypedHolesPhase.scala
cb372 marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
package holes

import dotty.tools.dotc.plugins.PluginPhase
import dotty.tools.dotc.typer.TyperPhase
import dotty.tools.dotc.core.Contexts.Context
import dotty.tools.dotc.ast.tpd
import dotty.tools.dotc.core.Symbols.defn
import org.jline.terminal.impl.jna.freebsd.CLibrary.termios
import dotty.tools.dotc.core.Types.Type
import dotty.tools.dotc.report
import dotty.tools.dotc.interfaces.SourcePosition
import dotty.tools.dotc.core.Names.TermName
import dotty.tools.dotc.core.Periods.PhaseId
import dotty.tools.dotc.transform.CheckUnused
import dotty.tools.dotc.transform.PostTyper
import scala.collection.mutable
import dotty.tools.dotc.ast.Trees.DefDef
import dotty.tools.dotc.ast.Trees.If
import dotty.tools.dotc.printing.RefinedPrinter
import dotty.tools.dotc.reporting.Diagnostic.Info
kasiaMarek marked this conversation as resolved.
Show resolved Hide resolved

class TypedHolesPhase(logLevel: LogLevel) extends PluginPhase:
override def phaseName: String = "typed-holes"
override val runsAfter: Set[String] = Set(TyperPhase.name)
override val runsBefore: Set[String] = Set(PostTyper.name)

private val bindings: mutable.Stack[Map[TermName, Binding]] =
new mutable.Stack
kasiaMarek marked this conversation as resolved.
Show resolved Hide resolved

private val defaultWidth = 1000

override def prepareForDefDef(tree: tpd.DefDef)(using Context): Context =
bindings.push(
tree.termParamss.flatten
.map(param => (param.name, Binding(param.tpt.tpe, param.sourcePos)))
.toMap
)
summon[Context]
kasiaMarek marked this conversation as resolved.
Show resolved Hide resolved

override def transformValDef(tree: tpd.ValDef)(using Context): tpd.Tree =
logHole(tree.rhs, tree.tpt.tpe)
tree

override def transformDefDef(tree: tpd.DefDef)(using Context): tpd.Tree =
logHole(tree.rhs, tree.tpt.tpe)
bindings.pop()
tree

override def transformApply(tree: tpd.Apply)(using Context): tpd.Tree =
tree match
case tpd.Apply(fun, args) =>
val paramIndex = countIndex(fun, 0)
args
.zip(fun.symbol.paramSymss(paramIndex))
.foreach:
case (arg, param) => logHole(arg, param.info)
tree
kasiaMarek marked this conversation as resolved.
Show resolved Hide resolved

private def countIndex(fun: tpd.Tree, index: Int)(using Context): Int =

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion: countIndex sounds a bit generic to me, maybe paramListIndex might be slightly more explicit about what the function does. I would also add a default value to allow calling it without the magic number in transformApply.

fun match
case tpd.Apply(f, _) if f.symbol == fun.symbol => countIndex(f, index + 1)
case _ => index

private def logHole(holeTree: tpd.Tree, tpe: => Type)(using Context): Unit =
holeTree match
case Hole(holeInRhs) => log(holeInRhs, tpe.widen)
case tpd.Block(_, rhs) => logHole(rhs, tpe)
case tpd.If(_, thenp, elsep) =>
logHole(thenp, tpe)
logHole(elsep, tpe)
case tpd.Match(_, caseDefs) => caseDefs.foreach(logHole(_, tpe))
case tpd.CaseDef(_, _, tree) => logHole(tree, tpe)
case _ =>

private def isNothing(tpe: Type)(using Context): Boolean =
kasiaMarek marked this conversation as resolved.
Show resolved Hide resolved
tpe.widen == defn.NothingType

private def collectRelevantBindings(using
ctx: Context
): Map[TermName, Binding] =
bindings.foldLeft(Map.empty[TermName, Binding]) { case (acc, level) =>
level ++ acc
}

private def log(holeTree: tpd.Tree, tpe: Type)(using Context): Unit = {
val printer = RefinedPrinter(summon[Context])
kasiaMarek marked this conversation as resolved.
Show resolved Hide resolved
def printType(tpe: Type) =
printer.toText(tpe).mkString(defaultWidth, false)

val relevantBindingsMessages =
collectRelevantBindings.toList
.sortBy(_._1.toString)
.map:
case (boundName, Binding(boundType, bindingPos)) =>
s" $boundName: ${printType(boundType)} (bound at ${posSummary(bindingPos)})"
.mkString("\n")

val holeName = holeTree.getAttachment(NamedHole.NamedHole)
val holeNameMsg = holeName.fold("")(x => s"'${x.name}' ")
val message =
if (!relevantBindingsMessages.isEmpty)
s"""
|Found hole ${holeNameMsg}with type: ${printType(tpe)}
|Relevant bindings include
|$relevantBindingsMessages
""".stripMargin
else
s"Found hole ${holeNameMsg}with type: ${printType(tpe)}"
logLevel match
case LogLevel.Info =>
summon[Context].reporter.report(new Info(message, holeTree.sourcePos))
case LogLevel.Warn => report.warning(message, holeTree.sourcePos)
case LogLevel.Error => report.error(message, holeTree.sourcePos)
}

private def posSummary(pos: SourcePosition)(using Context): String =
s"${pos.source().name()}:${pos.line}:${pos.column}"
kasiaMarek marked this conversation as resolved.
Show resolved Hide resolved

object Hole:
def unapply(tree: tpd.Tree)(using Context): Option[tpd.Tree] =
tree match
case _ if tree.symbol == defn.Predef_undefined => Some(tree)
case tpd.Block(_, expr) if tree.symbol == defn.Predef_undefined =>
kasiaMarek marked this conversation as resolved.
Show resolved Hide resolved
Some(expr)
case _ => None

case class Binding(tpe: Type, pos: SourcePosition)
28 changes: 28 additions & 0 deletions src/main/scala-3/holes/TypedHolesPlugin.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package holes
import dotty.tools.dotc.plugins.{PluginPhase, StandardPlugin}
import scala.util.control.ControlThrowable
import dotty.tools.dotc.report

class TypedHolesPlugin extends StandardPlugin:
val name: String = "typed-holes"
override def description: String =
"Treat use of ??? as a hole and give a useful warning about it"

override def init(options: List[String]): List[PluginPhase] =
val logLevel =
options
.flatMap: option =>
if option.startsWith("log-level:") then
option.substring("log-level:".length).toLowerCase match
case "info" => Some(LogLevel.Info)
case "warn" => Some(LogLevel.Warn)
case "error" => Some(LogLevel.Error)
case _ => None
else None
.headOption
.getOrElse(LogLevel.Warn)

List(
new NamedHolesPhase,
TypedHolesPhase(logLevel)
)
28 changes: 28 additions & 0 deletions src/test/resources/arguments/expected-3.3.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
-- Info: src/test/resources/arguments/input.scala:9:4 --------------------------
9 | f(???, if(1 == 2) ??? else ???)
| ^^^
| Found hole with type: Int
-- Info: src/test/resources/arguments/input.scala:9:20 -------------------------
9 | f(???, if(1 == 2) ??? else ???)
| ^^^
| Found hole with type: String
-- Info: src/test/resources/arguments/input.scala:9:29 -------------------------
9 | f(???, if(1 == 2) ??? else ???)
| ^^^
| Found hole with type: String
-- Info: src/test/resources/arguments/input.scala:12:8 -------------------------
12 | ff(1)(???){ ??? }
| ^^^
| Found hole with type: String
-- Info: src/test/resources/arguments/input.scala:12:14 ------------------------
12 | ff(1)(???){ ??? }
| ^^^
| Found hole with type: a.O
-- Info: src/test/resources/arguments/input.scala:16:4 -------------------------
16 | m(???)
| ^^^
|
| Found hole with type: Int
| Relevant bindings include
| evidence$1: a.A.OpaqueInt (bound at input.scala:15:4)
|
28 changes: 28 additions & 0 deletions src/test/resources/arguments/expected-3.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
-- Info: src/test/resources/arguments/input.scala:9:4 --------------------------
9 | f(???, if(1 == 2) ??? else ???)
| ^^^
| Found hole with type: Int
-- Info: src/test/resources/arguments/input.scala:9:20 -------------------------
9 | f(???, if(1 == 2) ??? else ???)
| ^^^
| Found hole with type: String
-- Info: src/test/resources/arguments/input.scala:9:29 -------------------------
9 | f(???, if(1 == 2) ??? else ???)
| ^^^
| Found hole with type: String
-- Info: src/test/resources/arguments/input.scala:12:8 -------------------------
12 | ff(1)(???){ ??? }
| ^^^
| Found hole with type: String
-- Info: src/test/resources/arguments/input.scala:12:14 ------------------------
12 | ff(1)(???){ ??? }
| ^^^
| Found hole with type: a.O
-- Info: src/test/resources/arguments/input.scala:16:4 -------------------------
16 | m(???)
| ^^^
|
| Found hole with type: Int
| Relevant bindings include
| contextual$1: a.A.OpaqueInt (bound at input.scala:15:4)
cb372 marked this conversation as resolved.
Show resolved Hide resolved
|
17 changes: 17 additions & 0 deletions src/test/resources/arguments/input.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package a

object A:
opaque type OpaqueInt = Int
given OpaqueInt = 1

object O:
def f(i: Int, s: String) = i
f(???, if(1 == 2) ??? else ???)

def ff(i: Int)(s: String)(u: O.type) = 1
ff(1)(???){ ??? }

import A.*
def m(j: OpaqueInt ?=> Int): Int = j
m(???)

Loading