Skip to content

Commit

Permalink
Add Scala 3 support
Browse files Browse the repository at this point in the history
  • Loading branch information
kasiaMarek committed May 6, 2024
1 parent 207f724 commit 1561f64
Show file tree
Hide file tree
Showing 28 changed files with 557 additions and 30 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ jobs:
- 2.13.6
- 2.13.7
- 2.13.8
- 3.3.3
- 3.4.1
java:
- 1.8
- 1.11
Expand Down
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
**/target
.*
!.gitignore
49 changes: 42 additions & 7 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,32 @@ crossScalaVersions := List(
"2.13.5",
"2.13.6",
"2.13.7",
"2.13.8"
"2.13.8",
"3.3.3",
"3.4.1"
)

crossVersion := CrossVersion.full
scalacOptions ++= Seq("-deprecation")

libraryDependencies ++= Seq(
scalaOrganization.value % "scala-compiler" % scalaVersion.value,
"org.scalatest" %% "scalatest" % "3.0.8" % Test,
scalaVersion.value match {
case version if version.startsWith("3.") =>
"org.scala-lang" %% "scala3-compiler" % scalaVersion.value
case _ =>
scalaOrganization.value % "scala-compiler" % scalaVersion.value
},
"org.scalatest" %% "scalatest" % "3.2.18" % Test,
"commons-io" % "commons-io" % "2.6" % Test
)

organization := "com.github.cb372"
homepage := Some(url("https://github.com/cb372/scala-typed-holes"))
licenses := Seq("Apache License, Version 2.0" -> url("http://www.apache.org/licenses/LICENSE-2.0.html"))
licenses := Seq(
"Apache License, Version 2.0" -> url(
"http://www.apache.org/licenses/LICENSE-2.0.html"
)
)
developers := List(
Developer(
"cb372",
Expand All @@ -39,17 +51,40 @@ developers := List(
Test / fork := true
Test / javaOptions ++= {
val jar = (Compile / packageBin).value
val scalacClasspath = scalaInstance.value.allJars.mkString(java.io.File.pathSeparator)
val scalacClasspath =
scalaInstance.value.allJars.mkString(java.io.File.pathSeparator)
Seq(
s"-Dplugin.jar=${jar.getAbsolutePath}",
s"-Dscalac.classpath=$scalacClasspath"
)
}

val `scala-typed-holes` = project.in(file("."))
val `scala-typed-holes` =
project
.in(file("."))
.enablePlugins(BuildInfoPlugin)

val `scala-typed-holes-3` = // just for IDE support
project
.in(file("scala-typed-holes-3"))
.settings(
scalaVersion := "3.4.1",
crossScalaVersions := Nil,
Compile / unmanagedSourceDirectories := Seq(),
Compile / unmanagedSourceDirectories += (ThisBuild / baseDirectory).value / "src" / "main" / "scala",
Compile / unmanagedSourceDirectories += (ThisBuild / baseDirectory).value / "src" / "main" / "scala-3",
moduleName := "scala-typed-holes-3",
target := (ThisBuild / baseDirectory).value / "target" / "target3",
publish / skip := true,
libraryDependencies ++= Seq(
"org.scala-lang" %% "scala3-compiler" % "3.4.1"
)
)

val docs = project
.in(file("generated-docs")) // important: it must not be the actual directory name, i.e. docs/
.in(
file("generated-docs")
) // important: it must not be the actual directory name, i.e. docs/
.settings(
scalaVersion := "2.12.10",
crossScalaVersions := Nil,
Expand Down
2 changes: 1 addition & 1 deletion project/build.properties
Original file line number Diff line number Diff line change
@@ -1 +1 @@
sbt.version=1.5.2
sbt.version=1.9.9
1 change: 1 addition & 0 deletions project/plugins.sbt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
addSbtPlugin("org.scalameta" % "sbt-mdoc" % "1.3.1")
addSbtPlugin("com.github.sbt" % "sbt-ci-release" % "1.5.10")
addSbtPlugin("com.eed3si9n" % "sbt-buildinfo" % "0.11.0")
1 change: 1 addition & 0 deletions src/main/resources/plugin.properties
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pluginClass=holes.TypedHolesPlugin
File renamed without changes.
47 changes: 47 additions & 0 deletions src/main/scala-3/holes/NamedHolesPhase.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package holes

import dotty.tools.dotc.plugins.PluginPhase
import dotty.tools.dotc.parsing.Parser
import dotty.tools.dotc.typer.TyperPhase
import dotty.tools.dotc.core.Contexts.Context
import dotty.tools.dotc.ast.tpd
import dotty.tools.dotc.ast.untpd
import dotty.tools.dotc.core.Names.Name
import dotty.tools.dotc.util.Property
import dotty.tools.dotc.core.Names.TermName
import dotty.tools.dotc.core.StdNames
import dotty.tools.dotc.report
import dotty.tools.dotc.core.Contexts.ctx
import dotty.tools.dotc.core.Contexts.atPhase
import dotty.tools.dotc.transform.MegaPhase
import dotty.tools.dotc.ast.untpd.TreeTraverser
import dotty.tools.dotc.ast.Trees.CaseDef
import dotty.tools.dotc.ast.untpd.UntypedTreeMap

class NamedHolesPhase extends PluginPhase:
override def phaseName: String = "named-holes"

override val runsAfter: Set[String] = Set(Parser.name)
override val runsBefore: Set[String] = Set(TyperPhase.name)

override def prepareForUnit(tree: tpd.Tree)(using Context): Context =
val namedHolesTreeMap = new UntypedTreeMap {
override def transform(tree: tpd.Tree)(using Context): tpd.Tree =
tree match
case tpd.Ident(NamedHole(name)) =>
val copied = cpy.Ident(tree)(StdNames.nme.???)
copied.putAttachment(NamedHole.NamedHole, HoleName(name))
copied
case _ =>
super.transform(tree)
}
ctx.compilationUnit.untpdTree =
namedHolesTreeMap.transform(ctx.compilationUnit.untpdTree)
ctx

object NamedHole:
val NamedHole: Property.Key[HoleName] = Property.StickyKey()
val pattern = "^__([a-zA-Z0-9_]+)$".r

def unapply(name: Name): Option[String] =
pattern.unapplySeq(name.decode.toString).flatMap(_.headOption)
127 changes: 127 additions & 0 deletions src/main/scala-3/holes/TypedHolesPhase.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
package holes

import dotty.tools.dotc.plugins.PluginPhase
import dotty.tools.dotc.typer.TyperPhase
import dotty.tools.dotc.core.Contexts.Context
import dotty.tools.dotc.ast.tpd
import dotty.tools.dotc.core.Symbols.defn
import org.jline.terminal.impl.jna.freebsd.CLibrary.termios
import dotty.tools.dotc.core.Types.Type
import dotty.tools.dotc.report
import dotty.tools.dotc.interfaces.SourcePosition
import dotty.tools.dotc.core.Names.TermName
import dotty.tools.dotc.core.Periods.PhaseId
import dotty.tools.dotc.transform.CheckUnused
import dotty.tools.dotc.transform.PostTyper
import scala.collection.mutable
import dotty.tools.dotc.ast.Trees.DefDef
import dotty.tools.dotc.ast.Trees.If
import dotty.tools.dotc.printing.RefinedPrinter
import dotty.tools.dotc.reporting.Diagnostic.Info

class TypedHolesPhase(logLevel: LogLevel) extends PluginPhase:
override def phaseName: String = "typed-holes"
override val runsAfter: Set[String] = Set(TyperPhase.name)
override val runsBefore: Set[String] = Set(PostTyper.name)

private val bindings: mutable.Stack[Map[TermName, Binding]] =
new mutable.Stack

private val defaultWidth = 1000

override def prepareForDefDef(tree: tpd.DefDef)(using Context): Context =
bindings.push(
tree.termParamss.flatten
.map(param => (param.name, Binding(param.tpt.tpe, param.sourcePos)))
.toMap
)
summon[Context]

override def transformValDef(tree: tpd.ValDef)(using Context): tpd.Tree =
logHole(tree.rhs, tree.tpt.tpe)
tree

override def transformDefDef(tree: tpd.DefDef)(using Context): tpd.Tree =
logHole(tree.rhs, tree.tpt.tpe)
bindings.pop()
tree

override def transformApply(tree: tpd.Apply)(using Context): tpd.Tree =
tree match
case tpd.Apply(fun, args) =>
val paramIndex = countIndex(fun, 0)
args
.zip(fun.symbol.paramSymss(paramIndex))
.foreach:
case (arg, param) => logHole(arg, param.info)
tree

private def countIndex(fun: tpd.Tree, index: Int)(using Context): Int =
fun match
case tpd.Apply(f, _) if f.symbol == fun.symbol => countIndex(f, index + 1)
case _ => index

private def logHole(holeTree: tpd.Tree, tpe: => Type)(using Context): Unit =
holeTree match
case Hole(holeInRhs) => log(holeInRhs, tpe.widen)
case tpd.Block(_, rhs) => logHole(rhs, tpe)
case tpd.If(_, thenp, elsep) =>
logHole(thenp, tpe)
logHole(elsep, tpe)
case tpd.Match(_, caseDefs) => caseDefs.foreach(logHole(_, tpe))
case tpd.CaseDef(_, _, tree) => logHole(tree, tpe)
case _ =>

private def isNothing(tpe: Type)(using Context): Boolean =
tpe.widen == defn.NothingType

private def collectRelevantBindings(using
ctx: Context
): Map[TermName, Binding] =
bindings.foldLeft(Map.empty[TermName, Binding]) { case (acc, level) =>
level ++ acc
}

private def log(holeTree: tpd.Tree, tpe: Type)(using Context): Unit = {
val printer = RefinedPrinter(summon[Context])
def printType(tpe: Type) =
printer.toText(tpe).mkString(defaultWidth, false)

val relevantBindingsMessages =
collectRelevantBindings.toList
.sortBy(_._1.toString)
.map:
case (boundName, Binding(boundType, bindingPos)) =>
s" $boundName: ${printType(boundType)} (bound at ${posSummary(bindingPos)})"
.mkString("\n")

val holeName = holeTree.getAttachment(NamedHole.NamedHole)
val holeNameMsg = holeName.fold("")(x => s"'${x.name}' ")
val message =
if (!relevantBindingsMessages.isEmpty)
s"""
|Found hole ${holeNameMsg}with type: ${printType(tpe)}
|Relevant bindings include
|$relevantBindingsMessages
""".stripMargin
else
s"Found hole ${holeNameMsg}with type: ${printType(tpe)}"
logLevel match
case LogLevel.Info =>
summon[Context].reporter.report(new Info(message, holeTree.sourcePos))
case LogLevel.Warn => report.warning(message, holeTree.sourcePos)
case LogLevel.Error => report.error(message, holeTree.sourcePos)
}

private def posSummary(pos: SourcePosition)(using Context): String =
s"${pos.source().name()}:${pos.line}:${pos.column}"

object Hole:
def unapply(tree: tpd.Tree)(using Context): Option[tpd.Tree] =
tree match
case _ if tree.symbol == defn.Predef_undefined => Some(tree)
case tpd.Block(_, expr) if tree.symbol == defn.Predef_undefined =>
Some(expr)
case _ => None

case class Binding(tpe: Type, pos: SourcePosition)
28 changes: 28 additions & 0 deletions src/main/scala-3/holes/TypedHolesPlugin.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package holes
import dotty.tools.dotc.plugins.{PluginPhase, StandardPlugin}
import scala.util.control.ControlThrowable
import dotty.tools.dotc.report

class TypedHolesPlugin extends StandardPlugin:
val name: String = "typed-holes"
override def description: String =
"Treat use of ??? as a hole and give a useful warning about it"

override def init(options: List[String]): List[PluginPhase] =
val logLevel =
options
.flatMap: option =>
if option.startsWith("log-level:") then
option.substring("log-level:".length).toLowerCase match
case "info" => Some(LogLevel.Info)
case "warn" => Some(LogLevel.Warn)
case "error" => Some(LogLevel.Error)
case _ => None
else None
.headOption
.getOrElse(LogLevel.Warn)

List(
new NamedHolesPhase,
TypedHolesPhase(logLevel)
)
28 changes: 28 additions & 0 deletions src/test/resources/arguments/expected-3.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
-- Info: src/test/resources/arguments/input.scala:9:4 --------------------------
9 | f(???, if(1 == 2) ??? else ???)
| ^^^
| Found hole with type: Int
-- Info: src/test/resources/arguments/input.scala:9:20 -------------------------
9 | f(???, if(1 == 2) ??? else ???)
| ^^^
| Found hole with type: String
-- Info: src/test/resources/arguments/input.scala:9:29 -------------------------
9 | f(???, if(1 == 2) ??? else ???)
| ^^^
| Found hole with type: String
-- Info: src/test/resources/arguments/input.scala:12:8 -------------------------
12 | ff(1)(???){ ??? }
| ^^^
| Found hole with type: String
-- Info: src/test/resources/arguments/input.scala:12:14 ------------------------
12 | ff(1)(???){ ??? }
| ^^^
| Found hole with type: a.O
-- Info: src/test/resources/arguments/input.scala:16:4 -------------------------
16 | m(???)
| ^^^
|
| Found hole with type: Int
| Relevant bindings include
| contextual$1: a.A.OpaqueInt (bound at input.scala:15:4)
|
17 changes: 17 additions & 0 deletions src/test/resources/arguments/input.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package a

object A:
opaque type OpaqueInt = Int
given OpaqueInt = 1

object O:
def f(i: Int, s: String) = i
f(???, if(1 == 2) ??? else ???)

def ff(i: Int)(s: String)(u: O.type) = 1
ff(1)(???){ ??? }

import A.*
def m(j: OpaqueInt ?=> Int): Int = j
m(???)

27 changes: 27 additions & 0 deletions src/test/resources/if-else/expected-3.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
-- Info: src/test/resources/if-else/input.scala:7:6 ----------------------------
7 | ???
| ^^^
|
| Found hole with type: Boolean
| Relevant bindings include
| x: Int (bound at input.scala:4:8)
| y: String (bound at input.scala:4:16)
|
-- Info: src/test/resources/if-else/input.scala:10:6 ---------------------------
10 | ???
| ^^^
|
| Found hole with type: Boolean
| Relevant bindings include
| x: Int (bound at input.scala:4:8)
| y: String (bound at input.scala:4:16)
|
-- Info: src/test/resources/if-else/input.scala:15:23 --------------------------
15 | if (y.length == x) ???
| ^^^
|
| Found hole with type: Boolean
| Relevant bindings include
| x: Int (bound at input.scala:13:8)
| y: String (bound at input.scala:13:16)
|
16 changes: 16 additions & 0 deletions src/test/resources/named-holes/expected-3.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
-- Info: src/test/resources/named-holes/input.scala:12:25 ----------------------
12 | case Left(error) => __left
| ^^^^^^
|
| Found hole 'left' with type: Option[Result]
| Relevant bindings include
| args: Array[String] (bound at input.scala:10:12)
|
-- Info: src/test/resources/named-holes/input.scala:13:25 ----------------------
13 | case Right(x) => __right
| ^^^^^^^
|
| Found hole 'right' with type: Option[Result]
| Relevant bindings include
| args: Array[String] (bound at input.scala:10:12)
|
Loading

0 comments on commit 1561f64

Please sign in to comment.