Skip to content

Commit

Permalink
feat: add support for Scala 2 #14
Browse files Browse the repository at this point in the history
* feat: add support for Scala 2 #14

* fix: nopenopenope, no braceless syntax :/

* fix: wip, traverse & log redacted fields

* fix: y u no patch :/

* fix: it works!

* fix: added test for curried case class ctor

* fix: crosscompile for scala 2.12.x and updated 3.4.x

* chore: updated README.md

* chore: bump version
  • Loading branch information
polentino authored May 12, 2024
1 parent f432114 commit 288a99d
Show file tree
Hide file tree
Showing 11 changed files with 298 additions and 21 deletions.
30 changes: 28 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ in your `build.sbt` file, add the following lines

```scala 3
val redactedVersion = // use latest version of the library
resolvers += DefaultMavenRepository,
resolvers += DefaultMavenRepository
,
libraryDependencies ++= Seq(
"io.github.polentino" %% "redacted" % redactedVersion cross CrossVersion.full,
compilerPlugin("io.github.polentino" %% "redacted-plugin" % redactedVersion cross CrossVersion.full)
Expand Down Expand Up @@ -131,6 +132,30 @@ println(wrapper)
will print
> Wrapper(id-1,***)
### Note on curried case classes

While it is possible to write something like

```scala 3
case class Curried(id: String, @redacted name: String)(@redacted email: String)
```

the `toString` method that Scala compiler generates by default will print only the parameters in the primary
constructor, meaning that

```scala 3
val c = Curried(0, "Berfu")("berfu@gmail.com")
println(c)
```

will display

```scala 3
Curried(0,Berfu)
```

Therefore, the same behavior is being kept in the customized `toString` implementation.

## How it works

Given a case class with at least one field annotated with `@redacted`, i.e.
Expand All @@ -157,7 +182,8 @@ implementation by selectively returning either the `***` string, or the value of

```scala 3
def toString(): String =
"<class name>(" + this.<field not redacted> + "," + "***" + ... + ")"
"<class name>(" + this.< field not redacted > + "," + "***" +
...+")"
```

## Improvements
Expand Down
15 changes: 11 additions & 4 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@ val scalaCheckVersion = "3.2.17.0"

// all LTS versions & latest minor ones
val supportedScalaVersions = List(
"2.12.19",
"2.13.13",
"3.1.3",
"3.2.2",
"3.3.0",
"3.3.1",
"3.3.3",
"3.4.0"
"3.4.1"
)

inThisBuild(
Expand Down Expand Up @@ -64,7 +66,11 @@ lazy val redactedCompilerPlugin = (project in file("plugin"))
.settings(name := "redacted-plugin")
.settings(
crossCompileSettings,
libraryDependencies += "org.scala-lang" %% "scala3-compiler" % scalaVersion.value
libraryDependencies += (CrossVersion.partialVersion(scalaVersion.value) match {
case Some((3, _)) => "org.scala-lang" %% "scala3-compiler" % scalaVersion.value
case Some((2, _)) => "org.scala-lang" % "scala-compiler" % scalaVersion.value
case v => throw new Exception(s"Scala version $v not recognised")
})
)

lazy val redactedTests = (project in file("tests"))
Expand All @@ -80,9 +86,10 @@ lazy val redactedTests = (project in file("tests"))
),
Test / scalacOptions ++= {
val jar = (redactedCompilerPlugin / Compile / packageBin).value
val addPlugin = "-Xplugin:" + jar.getAbsolutePath
val addScala2Plugin = "-Xplugin-require:redacted-plugin"
val addScala3Plugin = "-Xplugin:" + jar.getAbsolutePath
val dummy = "-Jdummy=" + jar.lastModified
Seq(addPlugin, dummy)
Seq(addScala2Plugin, addScala3Plugin, dummy)
}
)

Expand Down
4 changes: 4 additions & 0 deletions plugin/src/main/resources/scalac-plugin.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
<plugin>
<name>redacted-plugin</name>
<classname>io.github.polentino.redacted.RedactedPlugin</classname>
</plugin>
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package io.github.polentino.redacted

import scala.tools.nsc._
import scala.tools.nsc.plugins.Plugin
import scala.tools.nsc.plugins.PluginComponent

final class RedactedPlugin(override val global: Global) extends Plugin {

override val name: String = "redacted-plugin"

override val description: String = "Plugin to prevent leaking sensitive data when logging case classes"

override val components: List[PluginComponent] = List(new RedactedPluginComponent(global))
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
package io.github.polentino.redacted

import scala.tools.nsc.backend.jvm.GenBCode
import scala.tools.nsc.plugins.PluginComponent
import scala.tools.nsc.transform.Transform
import scala.util.Success
import scala.tools.nsc.Global

class RedactedPluginComponent(val global: Global) extends PluginComponent with Transform {

override val phaseName: String = "patch-tostring-component"

override val runsAfter: List[String] = List("parser")

override val runsRightAfter: Option[String] = Some("parser")

import global._

override protected def newTransformer(unit: CompilationUnit): Transformer = ToStringMaskerTransformer

private object ToStringMaskerTransformer extends Transformer {

private val TO_STRING_NAME = "toString"
private val redactedTypeName = TypeName("redacted")

override def transform(tree: Tree): Tree = {
val transformedTree = super.transform(tree)
validate(transformedTree) match {
case None => transformedTree
case Some(validatedClassDef) =>
val maybePatchedClassDef = for {
newToStringBody <- createToStringBody(validatedClassDef)
.withLog(s"couldn't create a valid toString body for ${validatedClassDef.name.decode}")

newToStringMethod <- buildToStringMethod(newToStringBody)
.withLog(s"couldn't create a valid toString body for ${validatedClassDef.name.decode}")

patchedClassDef <- patchCaseClass(validatedClassDef, newToStringMethod)
.withLog(s"couldn't create a valid toString body for ${validatedClassDef.name.decode}")

} yield patchedClassDef

maybePatchedClassDef match {
case Some(patchedClassDef) => patchedClassDef
case None =>
reporter.warning(
tree.pos,
s"""
|Dang, couldn't patch properly ${tree.symbol.nameString} :(
|If you believe this is an error: please report the issue, along with a minimum reproducible example,
|at the following link: https://github.com/polentino/redacted/issues/new .
|
|Thank you 🙏
|""".stripMargin
)
tree
}
}
}

/** Utility method that ensures the current tree being inspected is a case class with at least one parameter
* annotated with `@redacted`.
* @param tree
* the tree to be checked
* @return
* an option containing the validated `ClassDef`, or `None`
*/
private def validate(tree: Tree): Option[global.ClassDef] = for {
caseClassType <- validateTypeDef(tree)
_ <- getRedactedFields(caseClassType)
} yield caseClassType

/** Utility method that checks whether the current tree being inspected corresponds to a case class.
* @param tree
* the tree to be checked
* @return
* an option containing the validated `ClassDef`, or `None`
*/
private def validateTypeDef(tree: Tree): Option[ClassDef] = tree match {
case classDef: ClassDef if classDef.mods.isCase => Some(classDef)
case _ => None
}

/** Utility method that returns all ctor fields annotated with `@redacted`
* @param classDef
* the ClassDef to be checked
* @return
* an Option with the list of all params marked with `@redacted`, or `None` otherwise
*/
private def getRedactedFields(classDef: ClassDef): Option[List[ValDef]] =
classDef.impl.body.collectFirst {
case d: DefDef if d.name.decode == GenBCode.INSTANCE_CONSTRUCTOR_NAME =>
d.vparamss.headOption.fold(List.empty[ValDef])(v => v.filter(_.mods.hasAnnotationNamed(redactedTypeName)))
}

/** Utility method to generate a new `toString` definition based on the parameters marked with `@redacted`.
* @param classDef
* the ClassDef for which we need a dedicated `toString` method
* @return
* the body of the new `toString` method
*/
private def createToStringBody(classDef: ClassDef): scala.util.Try[Tree] = scala.util.Try {
val className = classDef.name.decode
val memberNames = getAllFields(classDef)
val classPrefix = (className + "(").toConstantLiteral
val classSuffix = ")".toConstantLiteral
val commaSymbol = ",".toConstantLiteral
val asterisksSymbol = "***".toConstantLiteral
val concatOperator = TermName("$plus")

val fragments: List[Tree] = memberNames.map(m =>
if (m.mods.hasAnnotationNamed(redactedTypeName)) asterisksSymbol
else Apply(Select(Ident(m.name), TO_STRING_NAME), Nil))

def buildToStringTree(fragments: List[Tree]): Tree = {

def concatAll(l: List[Tree]): List[Tree] = l match {
case Nil => Nil
case head :: Nil => List(head)
case head :: tail => List(head, commaSymbol) ++ concatAll(tail)
}

val res = concatAll(fragments).fold(classPrefix) { case (accumulator, fragment) =>
Apply(Select(accumulator, concatOperator), List(fragment))
}
Apply(Select(res, concatOperator), List(classSuffix))
}

buildToStringTree(fragments)
}

/** Returns all the fields in a case class ctor.
* @param classDef
* the `ClassDef` for which we want to get all if ctor field
* @return
* a list of all the `ValDef`
*/
private def getAllFields(classDef: ClassDef): List[ValDef] =
classDef.impl.body.collectFirst {
case d: DefDef if d.name.decode == GenBCode.INSTANCE_CONSTRUCTOR_NAME => d.vparamss.headOption.getOrElse(Nil)
}.getOrElse(Nil)

/** Build a new `toString` method definition containing the body passed as parameter.
* @param body
* the body of the newly created `toString` method
* @return
* the whole `toString` method definition
*/
private def buildToStringMethod(body: Tree): scala.util.Try[DefDef] = scala.util.Try {
DefDef(Modifiers(Flag.OVERRIDE), TermName(TO_STRING_NAME), Nil, Nil, TypeTree(), body)
}

/** Utility method that adds a new method definition to an existing `ClassDef` body.
* @param classDef
* the class that needs to be patched
* @param newToStringMethod
* the new method that will be included in the `ClassDef` passed as first parameter
* @return
* the patched `ClassDef`
*/
private def patchCaseClass(classDef: ClassDef, newToStringMethod: Tree): scala.util.Try[ClassDef] =
scala.util.Try {
val newBody = classDef.impl.body :+ newToStringMethod
val newImpl = classDef.impl.copy(body = newBody)
classDef.copy(impl = newImpl)
}

// utility extension classes

private implicit class AstOps(s: String) {
def toConstantLiteral: Literal = Literal(Constant(s))
}

private implicit class TryOps[Out](opt: scala.util.Try[Out]) {

def withLog(message: String): Option[Out] = opt match {
case Success(value) => Some(value)
case _ => reporter.echo(message); None
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import io.github.polentino.redacted.phases._
class RedactedPlugin extends StandardPlugin {
override def init(options: List[String]): List[PluginPhase] = List(PatchToString())

override def name: String = "Redacted"
override def name: String = "redacted-plugin"

override def description: String = "Plugin to prevent leaking sensitive data when logging case classes"
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@ object AstOps {

def redactedFields: List[String] = {
val redactedType = redactedSymbol
symbol.primaryConstructor.paramSymss.flatten.collect {
case s if s.annotations.exists(_.matches(redactedType)) => s.name.toString
symbol.primaryConstructor.paramSymss.headOption.fold(List.empty[String]) { params =>
params
.filter(_.annotations.exists(_.matches(redactedType)))
.map(_.name.toString)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ object PluginOps {
*/
def createToStringBody(tree: tpd.TypeDef)(using Context): Try[tpd.Tree] = Try {
val className = tree.name.toString
val memberNames = tree.symbol.primaryConstructor.paramSymss.flatten
val memberNames = tree.symbol.primaryConstructor.paramSymss.headOption.getOrElse(Nil)
val annotationSymbol = redactedSymbol
val classPrefix = (className + "(").toConstantLiteral
val classSuffix = ")".toConstantLiteral
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ final case class PatchToString() extends PluginPhase {
case None => tree
case Some(validatedTree) =>
val maybeNewTypeDef = for {

template <- getTreeTemplate(validatedTree)
.withLog(s"can't extract proper `tpd.Template` from ${tree.name}")

Expand All @@ -36,23 +35,24 @@ final case class PatchToString() extends PluginPhase {
.withLog(s"couldn't patch ${tree.name} template into ${tree.name} typedef")
} yield result

maybeNewTypeDef match
maybeNewTypeDef match {
case Some(newTypeDef) => newTypeDef
case None =>
report.warning(
s"""
|Dang, couldn't patch properly ${tree.name} :(
|If you believe this is an error: please report the issue, along with a minimum reproducible example,
|at the following link: https://github.com/polentino/redacted/issues/new .
|
|Thank you 🙏
|""".stripMargin,
|Dang, couldn't patch properly ${tree.name} :(
|If you believe this is an error: please report the issue, along with a minimum reproducible example,
|at the following link: https://github.com/polentino/redacted/issues/new .
|
|Thank you 🙏
|""".stripMargin,
tree.srcPos
)
tree
}
}
}

object PatchToString {
final val name: String = "PatchToString"
final val name: String = "patch-tostring-phase"
}
Loading

0 comments on commit 288a99d

Please sign in to comment.