Skip to content

Commit

Permalink
compiler: account for Scala 2 and 3 differences in generated vararg code
Browse files Browse the repository at this point in the history
  • Loading branch information
ahjohannessen committed Feb 22, 2024
1 parent ace8440 commit 125dff2
Show file tree
Hide file tree
Showing 7 changed files with 70 additions and 11 deletions.
11 changes: 11 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,14 @@ project/plugins/project/
.bloop/

compiler/version.properties

.vscode/

# Metals
.metals/
.bloop/
project/metals.sbt
metals.sbt

# BSP
.bsp/*
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ public static Optional<File> compile(
List<String> constructorAnnotations,
Codec codec,
boolean inclusiveDot) {
String scalaVersion = play.twirl.compiler.BuildInfo$.MODULE$.scalaVersion();
Seq<String> scalaAdditionalImports = toScalaSeq(additionalImports);
Seq<String> scalaConstructorAnnotations = toScalaSeq(constructorAnnotations);

Expand All @@ -65,6 +66,7 @@ public static Optional<File> compile(
sourceDirectory,
generatedDirectory,
formatterType,
scalaVersion,
scalaAdditionalImports,
scalaConstructorAnnotations,
codec,
Expand Down
48 changes: 41 additions & 7 deletions compiler/src/main/scala/play/twirl/compiler/TwirlCompiler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,17 @@ case class GeneratedSourceVirtual(path: String) extends AbstractGeneratedSource

object TwirlCompiler {

// For constants that depends on the Scala 2 or 3 mode.
private[compiler] class ScalaCompat(emitScala3Sources: Boolean) {
val varargSplicesSyntax: String =
if (emitScala3Sources) "*" else ": _*"
}

private[compiler] object ScalaCompat {
def apply(scalaVersion: String): ScalaCompat =
new ScalaCompat(scalaVersion.startsWith("3."))
}

def defaultImports(scalaVersion: String) = {
val implicits = if (scalaVersion.startsWith("3.")) {
Seq(
Expand Down Expand Up @@ -195,6 +206,7 @@ object TwirlCompiler {
sourceDirectory: File,
generatedDirectory: File,
formatterType: String,
scalaVersion: String,
additionalImports: collection.Seq[String] = Nil,
constructorAnnotations: collection.Seq[String] = Nil,
codec: Codec = TwirlIO.defaultCodec,
Expand All @@ -211,6 +223,7 @@ object TwirlCompiler {
relativePath(source),
resultType,
formatterType,
scalaVersion,
additionalImports,
constructorAnnotations,
inclusiveDot
Expand All @@ -228,11 +241,13 @@ object TwirlCompiler {
sourceDirectory: File,
resultType: String,
formatterType: String,
scalaVersion: String,
additionalImports: collection.Seq[String] = Nil,
constructorAnnotations: collection.Seq[String] = Nil,
codec: Codec = TwirlIO.defaultCodec,
inclusiveDot: Boolean = false
) = {

val (templateName, generatedSource) = generatedFileVirtual(source, sourceDirectory, inclusiveDot)
val generated = parseAndGenerateCode(
templateName,
Expand All @@ -241,6 +256,7 @@ object TwirlCompiler {
relativePath(source),
resultType,
formatterType,
scalaVersion,
additionalImports,
constructorAnnotations,
inclusiveDot
Expand All @@ -252,13 +268,14 @@ object TwirlCompiler {
private def relativePath(file: File): String =
new File(".").toURI.relativize(file.toURI).getPath

def parseAndGenerateCode(
private def parseAndGenerateCode(
templateName: Array[String],
content: Array[Byte],
codec: Codec,
relativePath: String,
resultType: String,
formatterType: String,
scalaVersion: String,
additionalImports: collection.Seq[String],
constructorAnnotations: collection.Seq[String],
inclusiveDot: Boolean
Expand All @@ -274,6 +291,7 @@ object TwirlCompiler {
parsed,
resultType,
formatterType,
ScalaCompat(scalaVersion),
additionalImports,
constructorAnnotations
)
Expand Down Expand Up @@ -446,10 +464,12 @@ object TwirlCompiler {
root: Template,
resultType: String,
formatterType: String,
scalaCompat: ScalaCompat,
additionalImports: collection.Seq[String],
constructorAnnotations: collection.Seq[String]
): collection.Seq[Any] = {
val (renderCall, f, templateType) = TemplateAsFunctionCompiler.getFunctionMapping(root.params.str, resultType)
val (renderCall, f, templateType) =
TemplateAsFunctionCompiler.getFunctionMapping(root.params.str, resultType, scalaCompat)

// Get the imports that we need to include, filtering out empty imports
val imports: Seq[Any] = Seq(additionalImports.map(i => Seq("import ", i, "\n")), formatImports(root.topImports))
Expand Down Expand Up @@ -499,19 +519,29 @@ package """ :+ packageName :+ """
templateImports.map(_.replace("%format%", extension))
}

def generateFinalTemplate(
private[compiler] def generateFinalTemplate(
relativePath: String,
contents: Array[Byte],
packageName: String,
name: String,
root: Template,
resultType: String,
formatterType: String,
scalaCompat: ScalaCompat,
additionalImports: collection.Seq[String],
constructorAnnotations: collection.Seq[String]
): String = {
val generated =
generateCode(packageName, name, root, resultType, formatterType, additionalImports, constructorAnnotations)
generateCode(
packageName,
name,
root,
resultType,
formatterType,
scalaCompat,
additionalImports,
constructorAnnotations
)

Source.finalSource(relativePath, contents, generated, Hash(contents, additionalImports))
}
Expand All @@ -531,7 +561,11 @@ package """ :+ packageName :+ """
}
}

def getFunctionMapping(signature: String, returnType: String): (String, String, String) = {
private[compiler] def getFunctionMapping(
signature: String,
returnType: String,
sc: ScalaCompat
): (String, String, String) = {

val params: List[List[Term.Param]] =
try {
Expand Down Expand Up @@ -573,7 +607,7 @@ package """ :+ packageName :+ """
.map { p =>
p.name.toString + Option(p.decltpe.get.toString)
.filter(_.endsWith("*"))
.map(_ => ".toIndexedSeq:_*")
.map(_ => s".toIndexedSeq${sc.varargSplicesSyntax}")
.getOrElse("")
}
.mkString(",") + ")"
Expand Down Expand Up @@ -601,7 +635,7 @@ package """ :+ packageName :+ """
.map { p =>
p.name.toString + Option(p.decltpe.get.toString)
.filter(_.endsWith("*"))
.map(_ => ".toIndexedSeq:_*")
.map(_ => s".toIndexedSeq${sc.varargSplicesSyntax}")
.getOrElse("")
}
.mkString(",") + ")"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ object Helper {
sourceDir,
generatedDir,
"play.twirl.api.HtmlFormat",
scalaVersion,
additionalImports = TwirlCompiler.defaultImports(scalaVersion) ++ additionalImports
)

Expand Down
12 changes: 10 additions & 2 deletions compiler/src/test/scala-3/play/twirl/compiler/test/Helper.scala
Original file line number Diff line number Diff line change
Expand Up @@ -60,14 +60,19 @@ object Helper {
): CompiledTemplate[T] = {
val scalaVersion = play.twirl.compiler.BuildInfo.scalaVersion
val templateFile = new File(sourceDir, templateName)
val Some(generated) = twirlCompiler.compile(
val generatedOpt: Option[File] = twirlCompiler.compile(
templateFile,
sourceDir,
generatedDir,
"play.twirl.api.HtmlFormat",
scalaVersion,
additionalImports = TwirlCompiler.defaultImports(scalaVersion) ++ additionalImports
)

val generated = generatedOpt.getOrElse {
throw new FileNotFoundException(s"Could not find generated file for $templateName")
}

val mapper = GeneratedSource(generated)

val compilerArgs = Array(
Expand All @@ -94,7 +99,10 @@ object Helper {

class TestDriver(outDir: Path, compilerArgs: Array[String], path: Path) extends Driver {
def compile(): Reporter = {
val Some((toCompile, rootCtx)) = setup(compilerArgs :+ path.toAbsolutePath.toString, initCtx.fresh)
val setupOpt = setup(compilerArgs :+ path.toAbsolutePath.toString, initCtx.fresh)
val (toCompile, rootCtx) = setupOpt.getOrElse {
throw new Exception("Failed to initialize compiler")
}

val silentReporter = new ConsoleReporter.AbstractConsoleReporter {
def printMessage(msg: String): Unit = {
Expand Down
3 changes: 2 additions & 1 deletion sbt-twirl/src/main/scala/play/twirl/sbt/SbtTwirl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,8 @@ object SbtTwirl extends AutoPlugin {
(compileTemplates / includeFilter).value,
(compileTemplates / excludeFilter).value,
Codec(sourceEncoding.value),
streams.value.log
streams.value.log,
scalaVersion.value
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ object TemplateCompiler {
includeFilter: FileFilter,
excludeFilter: FileFilter,
codec: Codec,
log: Logger
log: Logger,
scalaVersion: String
): Seq[File] = {
try {
syncGenerated(targetDirectory, codec)
Expand All @@ -32,6 +33,7 @@ object TemplateCompiler {
sourceDirectory,
targetDirectory,
format,
scalaVersion,
imports,
constructorAnnotations,
codec,
Expand Down

0 comments on commit 125dff2

Please sign in to comment.