From c9bef0a8de2d4a150405535e0fbdc630cc2fcc9e Mon Sep 17 00:00:00 2001 From: Alex Henning Johannessen Date: Thu, 22 Feb 2024 15:35:28 +0000 Subject: [PATCH] compiler: account for Scala 2 and 3 differences in generated vararg code --- .gitignore | 11 +++++ .../japi/twirl/compiler/TwirlCompiler.java | 2 + .../play/twirl/compiler/TwirlCompiler.scala | 48 ++++++++++++++++--- .../play/twirl/compiler/test/Helper.scala | 1 + .../play/twirl/compiler/test/Helper.scala | 12 ++++- .../main/scala/play/twirl/sbt/SbtTwirl.scala | 3 +- .../play/twirl/sbt/TemplateCompiler.scala | 4 +- 7 files changed, 70 insertions(+), 11 deletions(-) diff --git a/.gitignore b/.gitignore index 7a23fd8c..1bf33424 100644 --- a/.gitignore +++ b/.gitignore @@ -17,3 +17,14 @@ project/plugins/project/ .bloop/ compiler/version.properties + +.vscode/ + +# Metals +.metals/ +.bloop/ +project/metals.sbt +metals.sbt + +# BSP +.bsp/* \ No newline at end of file diff --git a/compiler/src/main/java/play/japi/twirl/compiler/TwirlCompiler.java b/compiler/src/main/java/play/japi/twirl/compiler/TwirlCompiler.java index adafd572..e5cdbfdc 100644 --- a/compiler/src/main/java/play/japi/twirl/compiler/TwirlCompiler.java +++ b/compiler/src/main/java/play/japi/twirl/compiler/TwirlCompiler.java @@ -56,6 +56,7 @@ public static Optional compile( List constructorAnnotations, Codec codec, boolean inclusiveDot) { + String scalaVersion = play.twirl.compiler.BuildInfo$.MODULE$.scalaVersion(); Seq scalaAdditionalImports = toScalaSeq(additionalImports); Seq scalaConstructorAnnotations = toScalaSeq(constructorAnnotations); @@ -65,6 +66,7 @@ public static Optional compile( sourceDirectory, generatedDirectory, formatterType, + scalaVersion, scalaAdditionalImports, scalaConstructorAnnotations, codec, diff --git a/compiler/src/main/scala/play/twirl/compiler/TwirlCompiler.scala b/compiler/src/main/scala/play/twirl/compiler/TwirlCompiler.scala index ccac8c90..dc4d276d 100644 --- a/compiler/src/main/scala/play/twirl/compiler/TwirlCompiler.scala +++ b/compiler/src/main/scala/play/twirl/compiler/TwirlCompiler.scala @@ -162,6 +162,17 @@ case class GeneratedSourceVirtual(path: String) extends AbstractGeneratedSource object TwirlCompiler { + // For constants that depends on the Scala 2 or 3 mode. + private[compiler] class ScalaCompat(emitScala3Sources: Boolean) { + val varargSplicesSyntax: String = + if (emitScala3Sources) "*" else ": _*" + } + + private[compiler] object ScalaCompat { + def apply(scalaVersion: String): ScalaCompat = + new ScalaCompat(scalaVersion.startsWith("3.")) + } + def defaultImports(scalaVersion: String) = { val implicits = if (scalaVersion.startsWith("3.")) { Seq( @@ -195,6 +206,7 @@ object TwirlCompiler { sourceDirectory: File, generatedDirectory: File, formatterType: String, + scalaVersion: String, additionalImports: collection.Seq[String] = Nil, constructorAnnotations: collection.Seq[String] = Nil, codec: Codec = TwirlIO.defaultCodec, @@ -211,6 +223,7 @@ object TwirlCompiler { relativePath(source), resultType, formatterType, + scalaVersion, additionalImports, constructorAnnotations, inclusiveDot @@ -228,11 +241,13 @@ object TwirlCompiler { sourceDirectory: File, resultType: String, formatterType: String, + scalaVersion: String, additionalImports: collection.Seq[String] = Nil, constructorAnnotations: collection.Seq[String] = Nil, codec: Codec = TwirlIO.defaultCodec, inclusiveDot: Boolean = false ) = { + val (templateName, generatedSource) = generatedFileVirtual(source, sourceDirectory, inclusiveDot) val generated = parseAndGenerateCode( templateName, @@ -241,6 +256,7 @@ object TwirlCompiler { relativePath(source), resultType, formatterType, + scalaVersion, additionalImports, constructorAnnotations, inclusiveDot @@ -252,13 +268,14 @@ object TwirlCompiler { private def relativePath(file: File): String = new File(".").toURI.relativize(file.toURI).getPath - def parseAndGenerateCode( + private def parseAndGenerateCode( templateName: Array[String], content: Array[Byte], codec: Codec, relativePath: String, resultType: String, formatterType: String, + scalaVersion: String, additionalImports: collection.Seq[String], constructorAnnotations: collection.Seq[String], inclusiveDot: Boolean @@ -274,6 +291,7 @@ object TwirlCompiler { parsed, resultType, formatterType, + ScalaCompat(scalaVersion), additionalImports, constructorAnnotations ) @@ -446,10 +464,12 @@ object TwirlCompiler { root: Template, resultType: String, formatterType: String, + scalaCompat: ScalaCompat, additionalImports: collection.Seq[String], constructorAnnotations: collection.Seq[String] ): collection.Seq[Any] = { - val (renderCall, f, templateType) = TemplateAsFunctionCompiler.getFunctionMapping(root.params.str, resultType) + val (renderCall, f, templateType) = + TemplateAsFunctionCompiler.getFunctionMapping(root.params.str, resultType, scalaCompat) // Get the imports that we need to include, filtering out empty imports val imports: Seq[Any] = Seq(additionalImports.map(i => Seq("import ", i, "\n")), formatImports(root.topImports)) @@ -499,7 +519,7 @@ package """ :+ packageName :+ """ templateImports.map(_.replace("%format%", extension)) } - def generateFinalTemplate( + private[compiler] def generateFinalTemplate( relativePath: String, contents: Array[Byte], packageName: String, @@ -507,11 +527,21 @@ package """ :+ packageName :+ """ root: Template, resultType: String, formatterType: String, + scalaCompat: ScalaCompat, additionalImports: collection.Seq[String], constructorAnnotations: collection.Seq[String] ): String = { val generated = - generateCode(packageName, name, root, resultType, formatterType, additionalImports, constructorAnnotations) + generateCode( + packageName, + name, + root, + resultType, + formatterType, + scalaCompat, + additionalImports, + constructorAnnotations + ) Source.finalSource(relativePath, contents, generated, Hash(contents, additionalImports)) } @@ -531,7 +561,11 @@ package """ :+ packageName :+ """ } } - def getFunctionMapping(signature: String, returnType: String): (String, String, String) = { + private[compiler] def getFunctionMapping( + signature: String, + returnType: String, + sc: ScalaCompat + ): (String, String, String) = { val params: List[List[Term.Param]] = try { @@ -573,7 +607,7 @@ package """ :+ packageName :+ """ .map { p => p.name.toString + Option(p.decltpe.get.toString) .filter(_.endsWith("*")) - .map(_ => ".toIndexedSeq:_*") + .map(_ => s".toIndexedSeq${sc.varargSplicesSyntax}") .getOrElse("") } .mkString(",") + ")" @@ -601,7 +635,7 @@ package """ :+ packageName :+ """ .map { p => p.name.toString + Option(p.decltpe.get.toString) .filter(_.endsWith("*")) - .map(_ => ".toIndexedSeq:_*") + .map(_ => s".toIndexedSeq${sc.varargSplicesSyntax}") .getOrElse("") } .mkString(",") + ")" diff --git a/compiler/src/test/scala-2/play/twirl/compiler/test/Helper.scala b/compiler/src/test/scala-2/play/twirl/compiler/test/Helper.scala index 6448bf17..c61b53d0 100644 --- a/compiler/src/test/scala-2/play/twirl/compiler/test/Helper.scala +++ b/compiler/src/test/scala-2/play/twirl/compiler/test/Helper.scala @@ -99,6 +99,7 @@ object Helper { sourceDir, generatedDir, "play.twirl.api.HtmlFormat", + scalaVersion, additionalImports = TwirlCompiler.defaultImports(scalaVersion) ++ additionalImports ) diff --git a/compiler/src/test/scala-3/play/twirl/compiler/test/Helper.scala b/compiler/src/test/scala-3/play/twirl/compiler/test/Helper.scala index b53fe10b..63d65117 100644 --- a/compiler/src/test/scala-3/play/twirl/compiler/test/Helper.scala +++ b/compiler/src/test/scala-3/play/twirl/compiler/test/Helper.scala @@ -60,14 +60,19 @@ object Helper { ): CompiledTemplate[T] = { val scalaVersion = play.twirl.compiler.BuildInfo.scalaVersion val templateFile = new File(sourceDir, templateName) - val Some(generated) = twirlCompiler.compile( + val generatedOpt: Option[File] = twirlCompiler.compile( templateFile, sourceDir, generatedDir, "play.twirl.api.HtmlFormat", + scalaVersion, additionalImports = TwirlCompiler.defaultImports(scalaVersion) ++ additionalImports ) + val generated = generatedOpt.getOrElse { + throw new FileNotFoundException(s"Could not find generated file for $templateName") + } + val mapper = GeneratedSource(generated) val compilerArgs = Array( @@ -94,7 +99,10 @@ object Helper { class TestDriver(outDir: Path, compilerArgs: Array[String], path: Path) extends Driver { def compile(): Reporter = { - val Some((toCompile, rootCtx)) = setup(compilerArgs :+ path.toAbsolutePath.toString, initCtx.fresh) + val setupOpt = setup(compilerArgs :+ path.toAbsolutePath.toString, initCtx.fresh) + val (toCompile, rootCtx) = setupOpt.getOrElse { + throw new Exception("Failed to initialize compiler") + } val silentReporter = new ConsoleReporter.AbstractConsoleReporter { def printMessage(msg: String): Unit = { diff --git a/sbt-twirl/src/main/scala/play/twirl/sbt/SbtTwirl.scala b/sbt-twirl/src/main/scala/play/twirl/sbt/SbtTwirl.scala index b74123aa..cf022905 100644 --- a/sbt-twirl/src/main/scala/play/twirl/sbt/SbtTwirl.scala +++ b/sbt-twirl/src/main/scala/play/twirl/sbt/SbtTwirl.scala @@ -121,7 +121,8 @@ object SbtTwirl extends AutoPlugin { (compileTemplates / includeFilter).value, (compileTemplates / excludeFilter).value, Codec(sourceEncoding.value), - streams.value.log + streams.value.log, + scalaVersion.value ) } diff --git a/sbt-twirl/src/main/scala/play/twirl/sbt/TemplateCompiler.scala b/sbt-twirl/src/main/scala/play/twirl/sbt/TemplateCompiler.scala index beae4d51..26155385 100644 --- a/sbt-twirl/src/main/scala/play/twirl/sbt/TemplateCompiler.scala +++ b/sbt-twirl/src/main/scala/play/twirl/sbt/TemplateCompiler.scala @@ -20,7 +20,8 @@ object TemplateCompiler { includeFilter: FileFilter, excludeFilter: FileFilter, codec: Codec, - log: Logger + log: Logger, + scalaVersion: String ): Seq[File] = { try { syncGenerated(targetDirectory, codec) @@ -32,6 +33,7 @@ object TemplateCompiler { sourceDirectory, targetDirectory, format, + scalaVersion, imports, constructorAnnotations, codec,