diff --git a/be2-scala/linter/input/src/main/scala/fix/ArraysInFormat.scala b/be2-scala/linter/input/src/main/scala/fix/ArraysInFormat.scala index 62bb06201b..9e6f63a8d5 100644 --- a/be2-scala/linter/input/src/main/scala/fix/ArraysInFormat.scala +++ b/be2-scala/linter/input/src/main/scala/fix/ArraysInFormat.scala @@ -9,10 +9,15 @@ object ArraysInFormat { val str = f"Here are my cool elements ${array}" // assert: ArraysInFormat val str2 = s"Here are my cool elements ${array}" // assert: ArraysInFormat String.format("Here are my cool elements %d", array) // assert: ArraysInFormat + "Here are my cool elements %d".format(array) // assert: ArraysInFormat + val str3 = "Here are my cool elements %d" + str3.format(array) // assert: ArraysInFormat String.format("Here are my cool elements %d", array) /* assert: ArraysInFormat ^^^^^ Array passed to format / interpolate string */ + + "Here are my cool elements %d".format(13) // scalafix: ok; } } diff --git a/be2-scala/linter/input/src/main/scala/fix/EmptyInterpolatedString.scala b/be2-scala/linter/input/src/main/scala/fix/EmptyInterpolatedString.scala index 681086aff9..7f5646334d 100644 --- a/be2-scala/linter/input/src/main/scala/fix/EmptyInterpolatedString.scala +++ b/be2-scala/linter/input/src/main/scala/fix/EmptyInterpolatedString.scala @@ -6,7 +6,10 @@ package fix object EmptyInterpolatedString { def test(): Unit = { print(f"Here's my cute interpolation") // assert: EmptyInterpolatedString - print(s"Here's my amazing interpolationg") // assert: EmptyInterpolatedString - print(String.format("I'm hungry!")) // assert: EmptyInterpolatedString + print(s"Here's my amazing interpolation") // assert: EmptyInterpolatedString + String.format("I'm hungry!") // assert: EmptyInterpolatedString + val str = "I'm hungry!" + str.format() // assert: EmptyInterpolatedString + "I'm hungry!".format() // assert: EmptyInterpolatedString } } diff --git a/be2-scala/linter/input/src/main/scala/fix/IllegalFormatString.scala b/be2-scala/linter/input/src/main/scala/fix/IllegalFormatString.scala new file mode 100644 index 0000000000..d209d14033 --- /dev/null +++ b/be2-scala/linter/input/src/main/scala/fix/IllegalFormatString.scala @@ -0,0 +1,36 @@ +/* +rule = IllegalFormatString + */ +package fix + +object IllegalFormatString { + def test(): Unit = { + val name: String = "John" + val age: Integer = 30 + val illegalFormatString1 = "%s is %d years old, %d" + String.format(illegalFormatString1, name, age) // assert: IllegalFormatString + + illegalFormatString1.format(name, age) // assert: IllegalFormatString + + "%s is %d years old, %d".format(name, age) // assert: IllegalFormatString + + "%5.5q".format("sam") // assert: IllegalFormatString + + "% s".format("sam") // assert: IllegalFormatString + + "%qs".format("sam") // assert: IllegalFormatString + + "%.-5s".format("sam") // assert: IllegalFormatString + + "%.s".format("sam") // assert: IllegalFormatString + + "% "one", 2 -> "two") + numMap1.get(1).getOrElse("unknown") // assert: MapGetAndGetOrElse + + + val numMap2 = scala.collection.mutable.Map("one" -> 1, "two" -> 2) + numMap2.get("one").getOrElse(-1) // assert: MapGetAndGetOrElse + + Map("John" -> "Smith", "Peter" -> "Rabbit").get("Sarah").getOrElse("-") // assert: MapGetAndGetOrElse + + } + +} diff --git a/be2-scala/linter/input/src/main/scala/fix/NanComparison.scala b/be2-scala/linter/input/src/main/scala/fix/NanComparison.scala new file mode 100644 index 0000000000..6189182dd2 --- /dev/null +++ b/be2-scala/linter/input/src/main/scala/fix/NanComparison.scala @@ -0,0 +1,29 @@ +/* +rule = NanComparison + */ +package fix + +object NanComparison { + + def test() = { + val d = 0.5d + d == Double.NaN // assert: NanComparison + Double.NaN == d // assert: NanComparison + d.equals(Double.NaN) // assert: NanComparison + Double.NaN.equals(d) // assert: NanComparison + + val f = 0.5f + f == Double.NaN // assert: NanComparison + Double.NaN == f // assert: NanComparison + + val g = Double.NaN + f == g // assert: NanComparison + g == f // assert: NanComparison + + d == g // assert: NanComparison + g == d // assert: NanComparison + d.equals(g) // assert: NanComparison + g.equals(d) // assert: NanComparison + } + +} diff --git a/be2-scala/linter/input/src/main/scala/fix/OptionGet.scala b/be2-scala/linter/input/src/main/scala/fix/OptionGet.scala new file mode 100644 index 0000000000..33a5c5dfb1 --- /dev/null +++ b/be2-scala/linter/input/src/main/scala/fix/OptionGet.scala @@ -0,0 +1,13 @@ +/* +rule = OptionGet + */ +package fix + +object OptionGet { + def test() = { + val o = Option("olivia") + o.get // assert: OptionGet + + Option("layla").get // assert: OptionGet + } +} diff --git a/be2-scala/linter/input/src/main/scala/fix/OptionSize.scala b/be2-scala/linter/input/src/main/scala/fix/OptionSize.scala new file mode 100644 index 0000000000..f8b9130b76 --- /dev/null +++ b/be2-scala/linter/input/src/main/scala/fix/OptionSize.scala @@ -0,0 +1,13 @@ +/* +rule = OptionSize + */ +package fix + +object OptionSize { + def test() = { + val o = Option("olivia") + o.size // assert: OptionSize + + Option("layla").size // assert: OptionSize + } +} diff --git a/be2-scala/linter/input/src/main/scala/fix/StripMarginOnRegex.scala b/be2-scala/linter/input/src/main/scala/fix/StripMarginOnRegex.scala new file mode 100644 index 0000000000..d60007cf5f --- /dev/null +++ b/be2-scala/linter/input/src/main/scala/fix/StripMarginOnRegex.scala @@ -0,0 +1,13 @@ +/* +rule = StripMarginOnRegex + */ +package fix + +object StripMarginOnRegex { + def test() = { + val regex = "match|this".stripMargin.r // assert: StripMarginOnRegex + "match|that".stripMargin.r // assert: StripMarginOnRegex + "match_this".stripMargin.r // scalafix: ok; + "match|this".r // scalafix: ok; + } +} diff --git a/be2-scala/linter/input/src/main/scala/fix/TryGet.scala b/be2-scala/linter/input/src/main/scala/fix/TryGet.scala new file mode 100644 index 0000000000..e280c6f68e --- /dev/null +++ b/be2-scala/linter/input/src/main/scala/fix/TryGet.scala @@ -0,0 +1,15 @@ +/* +rule = TryGet + */ +package fix + +import scala.util.Try + +object TryGet { + def test() = { + val o = Try { println("mojave") } + o.get // assert: TryGet + + Try { println("ghost") }.get // assert: TryGet + } +} diff --git a/be2-scala/linter/project/TargetAxis.scala b/be2-scala/linter/project/TargetAxis.scala new file mode 100644 index 0000000000..e47512ebad --- /dev/null +++ b/be2-scala/linter/project/TargetAxis.scala @@ -0,0 +1,43 @@ +import sbt._ +import sbt.internal.ProjectMatrix +import sbtprojectmatrix.ProjectMatrixPlugin.autoImport._ + +/** Use on ProjectMatrix rows to tag an affinity to a custom scalaVersion */ +case class TargetAxis(scalaVersion: String) extends VirtualAxis.WeakAxis { + + private val scalaBinaryVersion = CrossVersion.binaryScalaVersion(scalaVersion) + + override val idSuffix = s"Target${scalaBinaryVersion.replace('.', '_')}" + override val directorySuffix = s"target$scalaBinaryVersion" +} + +object TargetAxis { + + private def targetScalaVersion(virtualAxes: Seq[VirtualAxis]): String = + virtualAxes.collectFirst { case a: TargetAxis => a.scalaVersion }.get + + /** When invoked on a ProjectMatrix with a TargetAxis, lookup the project generated by `matrix` with a scalaVersion matching the one declared in that TargetAxis, and resolve `key`. + */ + def resolve[T]( + matrix: ProjectMatrix, + key: TaskKey[T] + ): Def.Initialize[Task[T]] = + Def.taskDyn { + val sv = targetScalaVersion(virtualAxes.value) + val project = matrix.finder().apply(sv) + Def.task((project / key).value) + } + + /** When invoked on a ProjectMatrix with a TargetAxis, lookup the project generated by `matrix` with a scalaVersion matching the one declared in that TargetAxis, and resolve `key`. + */ + def resolve[T]( + matrix: ProjectMatrix, + key: SettingKey[T] + ): Def.Initialize[T] = + Def.settingDyn { + val sv = targetScalaVersion(virtualAxes.value) + val project = matrix.finder().apply(sv) + Def.setting((project / key).value) + } + +} diff --git a/be2-scala/linter/rules/src/main/resources/META-INF/services/scalafix.v1.Rule b/be2-scala/linter/rules/src/main/resources/META-INF/services/scalafix.v1.Rule index f5dfd9b08c..0b8e8f0e46 100644 --- a/be2-scala/linter/rules/src/main/resources/META-INF/services/scalafix.v1.Rule +++ b/be2-scala/linter/rules/src/main/resources/META-INF/services/scalafix.v1.Rule @@ -3,3 +3,13 @@ fix.CatchNpe fix.ComparingFloatingTypes fix.EmptyInterpolatedString fix.ImpossibleOptionSizeCondition +fix.IllegalFormatString +fix.IncorrectNumberOfArgsToFormat +fix.IncorrectlyNamedExceptions +fix.LonelySealedTrait +fix.MapGetAndGetOrElse +fix.NanComparison +fix.OptionGet +fix.OptionSize +fix.StripMarginOnRegex +fix.TryGet diff --git a/be2-scala/linter/rules/src/main/scala/fix/ArraysInFormat.scala b/be2-scala/linter/rules/src/main/scala/fix/ArraysInFormat.scala index ec7c74225b..b47c47a238 100644 --- a/be2-scala/linter/rules/src/main/scala/fix/ArraysInFormat.scala +++ b/be2-scala/linter/rules/src/main/scala/fix/ArraysInFormat.scala @@ -20,21 +20,20 @@ case class ArraysInFormatDiag(array: Tree) extends Diagnostic { class ArraysInFormat extends SemanticRule("ArraysInFormat") { - override def fix(implicit doc: SemanticDocument): Patch = { - - def rule(args: List[Stat]) = { - args.collect { - case t @ Term.Name(_) => - t.symbol.info match { - case Some(symInfo) => symInfo.signature match { - case ValueSignature(TypeRef(_, symbol, _)) if SymbolMatcher.exact("scala/Array#").matches(symbol) => Patch.lint(ArraysInFormatDiag(t)) - case _ => Patch.empty - } - case _ => Patch.empty + private def rule(args: List[Stat])(implicit doc: SemanticDocument) = { + args.collect { + case t @ Term.Name(_) => + t.symbol.info match { + case Some(symInfo) => symInfo.signature match { + case ValueSignature(TypeRef(_, symbol, _)) if SymbolMatcher.exact("scala/Array#").matches(symbol) => Patch.lint(ArraysInFormatDiag(t)) + case _ => Patch.empty } - } + case _ => Patch.empty + } } + } + override def fix(implicit doc: SemanticDocument): Patch = { doc.tree.collect { case Term.Apply.After_4_6_0(Term.Select(_, Term.Name("format")), Term.ArgClause(args, _)) => rule(args) case Term.Interpolate(_, _, args) => rule(args.flatMap { case Term.Block(stats) => stats }) // Args in Term.Interpolate come in Term.Block requiring to extract the stats diff --git a/be2-scala/linter/rules/src/main/scala/fix/ComparingFloatingTypes.scala b/be2-scala/linter/rules/src/main/scala/fix/ComparingFloatingTypes.scala index af99ed5925..5756839fe8 100644 --- a/be2-scala/linter/rules/src/main/scala/fix/ComparingFloatingTypes.scala +++ b/be2-scala/linter/rules/src/main/scala/fix/ComparingFloatingTypes.scala @@ -21,19 +21,9 @@ case class ComparingFloatingTypesDiag(floats: Tree) extends Diagnostic { class ComparingFloatingTypes extends SemanticRule("ComparingFloatingTypes") { override def fix(implicit doc: SemanticDocument): Patch = { - def getType(term: Term): Symbol = { - term.symbol.info match { - case Some(symInfo) => symInfo.signature match { - case ValueSignature(TypeRef(_, symbol, _)) => symbol - case _ => null - } - case _ => null - } - } def isFloatOrDouble(term: Term): Boolean = { - val floatOrDoubleMatcher = SymbolMatcher.exact("scala/Float#", "scala/Double#") - floatOrDoubleMatcher.matches(getType(term)) + Util.matchType(term, "scala/Float", "scala/Double") } doc.tree.collect { diff --git a/be2-scala/linter/rules/src/main/scala/fix/EmptyInterpolatedString.scala b/be2-scala/linter/rules/src/main/scala/fix/EmptyInterpolatedString.scala index adb9a9dd2f..e722477660 100644 --- a/be2-scala/linter/rules/src/main/scala/fix/EmptyInterpolatedString.scala +++ b/be2-scala/linter/rules/src/main/scala/fix/EmptyInterpolatedString.scala @@ -21,7 +21,7 @@ case class EmptyInterpolatedStringDiag(string: Tree) extends Diagnostic { class EmptyInterpolatedString extends SemanticRule("EmptyInterpolatedString") { override def fix(implicit doc: SemanticDocument): Patch = { doc.tree.collect { - case t @ Term.Apply.After_4_6_0(Term.Select(_, Term.Name("format")), Term.ArgClause(List(Lit.String(_)), _)) => Patch.lint(EmptyInterpolatedStringDiag(t)) + case t @ Term.Apply.After_4_6_0(Term.Select(_, Term.Name("format")), Term.ArgClause(List(Lit.String(_)), _) | Term.ArgClause(Nil, _)) => Patch.lint(EmptyInterpolatedStringDiag(t)) case t @ Term.Interpolate(_, _, Nil) => Patch.lint(EmptyInterpolatedStringDiag(t)) }.asPatch } diff --git a/be2-scala/linter/rules/src/main/scala/fix/IllegalFormatString.scala b/be2-scala/linter/rules/src/main/scala/fix/IllegalFormatString.scala new file mode 100644 index 0000000000..ae985feca3 --- /dev/null +++ b/be2-scala/linter/rules/src/main/scala/fix/IllegalFormatString.scala @@ -0,0 +1,73 @@ +/* +rule = IllegalFormatString + */ +package fix + +import fix.Util.{findDefinition, findDefinitions, findDefinitionsOrdered} +import scalafix.lint.LintSeverity + +import scala.meta._ +import scalafix.v1._ + +import java.util.{IllegalFormatException, MissingFormatArgumentException, UnknownFormatConversionException} + +case class IllegalFormatStringDiag(string: Tree) extends Diagnostic { + override def message: String = "Illegal format string" + + override def severity: LintSeverity = LintSeverity.Error + + override def explanation: String = "An unchecked exception will be thrown when a format string contains an illegal syntax or a format specifier that is incompatible with the given arguments" + + override def position: Position = string.pos +} + +class IllegalFormatString extends SemanticRule("IllegalFormatString") { + + //Term parameter is simply used to display the rule at the correct place + private def rule(term: Term, value: String, args: List[Any]): Patch = { + try value.format(args: _*) + catch { + case _: IllegalFormatException | _: MissingFormatArgumentException | _: UnknownFormatConversionException => return Patch.lint(IllegalFormatStringDiag(term)) + } + Patch.empty + } + + override def fix(implicit doc: SemanticDocument): Patch = { + + def getMappedArgs(args: List[Term]): List[Any] = { + (findDefinitionsOrdered(doc.tree, args) ++ args).collect { case Lit(value) => value } + } + + doc.tree.collect { + case t @ Term.Apply.After_4_6_0(Term.Select(qual, Term.Name("format")), Term.ArgClause(args, _)) => + qual match { + case Term.Name("String") => + //This case corresponds to String.format(format, args) + //String.format argclause has format string as first argument and rest are arguments, + // we can safely assume first argument is a string if we are doing a String.format() + args match { + case Lit.String(format) :: rest => rule(t, format, getMappedArgs(rest)) + case Term.Name(_) :: _ => + val mappedArgs = getMappedArgs(args) + rule(t, mappedArgs.head.toString, mappedArgs.tail) + case _ => Patch.empty + } + case strName @ Term.Name(_) => + //This case corresponds to a string declared as a val or var and used in a format call, e.g. + // val myStr = "Hello %s" + // myStr.format("world") + val str = findDefinition(doc.tree, strName) + str match { + case Lit.String(value) => rule(t, value, getMappedArgs(args)) + case _ => Patch.empty + } + case Lit.String(value) => + //This case corresponds to a string literal used in a format call, e.g. "Hello %s".format("world") + rule(t, value, getMappedArgs(args)) + case _ => Patch.empty + } + } + }.asPatch + + +} diff --git a/be2-scala/linter/rules/src/main/scala/fix/ImpossibleOptionSizeCondition.scala b/be2-scala/linter/rules/src/main/scala/fix/ImpossibleOptionSizeCondition.scala index 021316aca8..a0d8aa8cc4 100644 --- a/be2-scala/linter/rules/src/main/scala/fix/ImpossibleOptionSizeCondition.scala +++ b/be2-scala/linter/rules/src/main/scala/fix/ImpossibleOptionSizeCondition.scala @@ -22,20 +22,10 @@ class ImpossibleOptionSizeCondition extends SemanticRule("ImpossibleOptionSizeCo override def fix(implicit doc: SemanticDocument): Patch = { - def getType(term: Term): Symbol = { - term.symbol.info match { - case Some(symInfo) => symInfo.signature match { - case ValueSignature(TypeRef(_, symbol, _)) => symbol - case _ => null - } - case _ => null - } - } - doc.tree.collect { case t @ Term.ApplyInfix.After_4_6_0(Term.Select(qual, Term.Name("size")), Term.Name(">") | Term.Name(">="), _, Term.ArgClause(List(comparedValue), _)) - if SymbolMatcher.exact("scala/Option#", "scala/Some#").matches(getType(qual)) => + if Util.matchType(qual, "scala/Option", "scala/Some") => comparedValue match { case Lit.Int(actualValue) if actualValue >= 1 => Patch.lint(ImpossibleOptionSizeConditionDiag(t)) case _ => Patch.empty diff --git a/be2-scala/linter/rules/src/main/scala/fix/IncorrectNumberOfArgsToFormat.scala b/be2-scala/linter/rules/src/main/scala/fix/IncorrectNumberOfArgsToFormat.scala new file mode 100644 index 0000000000..b617d6cda2 --- /dev/null +++ b/be2-scala/linter/rules/src/main/scala/fix/IncorrectNumberOfArgsToFormat.scala @@ -0,0 +1,63 @@ +/* +rule = IncorrectNumberOfArgsToFormat + */ +package fix + +import fix.Util.findDefinition +import scalafix.lint.LintSeverity +import scalafix.v1._ + +import scala.meta._ + +case class IncorrectNumberOfArgsToFormatDiag(string: Tree) extends Diagnostic { + override def message: String = "Incorrect number of arguments to format" + + override def severity: LintSeverity = LintSeverity.Error + + override def explanation: String = "The number of arguments passed to String.format doesn't correspond to the number of fields in the format string." + + override def position: Position = string.pos +} + +class IncorrectNumberOfArgsToFormat extends SemanticRule("IncorrectNumberOfArgsToFormat") { + + private val argRegex = "%((\\d+\\$)?[-#+ 0,(\\<]*?\\d?(\\.\\d+)?[tT]?[a-zA-Z]|%)".r + + private def rule(format: String, args: List[Any], t: Term): Patch = { + val argCount = argRegex + .findAllIn(format) + .matchData + .count(m => !doesNotTakeArguments(m.matched)) + + if (argCount != args.size) Patch.lint(IncorrectNumberOfArgsToFormatDiag(t)) else Patch.empty + } + + override def fix(implicit doc: SemanticDocument): Patch = { + doc.tree.collect { + case t @ Term.Apply.After_4_6_0(Term.Select(qual, Term.Name("format")), Term.ArgClause(args, _)) => + qual match { + case Term.Name("String") => + args match { + case Lit.String(format) :: rest => rule(format, rest, t) + case (head @ Term.Name(_)) :: rest => + val format = findDefinition(doc.tree, head) + format match { + case Lit.String(value) => rule(value, rest, t) + } + case _ => Patch.empty + } + case strName @ Term.Name(_) => val str = findDefinition(doc.tree, strName) + str match { + case Lit.String(value) => rule(value, args, t) + case _ => Patch.empty + } + case Lit.String(value) => rule(value, args, t) + + } + case _ => Patch.empty + }.asPatch + } + + private def doesNotTakeArguments(formatSpecifier: String): Boolean = + formatSpecifier == "%%" || formatSpecifier == "%n" +} diff --git a/be2-scala/linter/rules/src/main/scala/fix/IncorrectlyNamedExceptions.scala b/be2-scala/linter/rules/src/main/scala/fix/IncorrectlyNamedExceptions.scala new file mode 100644 index 0000000000..a2eaf1e8cf --- /dev/null +++ b/be2-scala/linter/rules/src/main/scala/fix/IncorrectlyNamedExceptions.scala @@ -0,0 +1,55 @@ +/* +rule = IncorrectlyNamedExceptions + */ +package fix + +import scalafix.lint.LintSeverity + +import scala.meta._ +import scalafix.v1._ + +case class IncorrectlyNamedExceptionsDiag(exception: Tree) extends Diagnostic { + override def message: String = "Incorrectly named exceptions" + + override def severity: LintSeverity = LintSeverity.Error + + override def explanation: String = "Class named exception does not derive from Exception / class derived from Exception is not named *Exception." + + override def position: Position = exception.pos +} + +class IncorrectlyNamedExceptions extends SemanticRule("IncorrectlyNamedExceptions") { + + private def inheritsFromException(symbol: Symbol)(implicit doc: SemanticDocument): Boolean = { + symbol.info match { + case Some(info) => + // Check if the current type contains an Exception in its name + info.signature match { + case ClassSignature(_, _, _, _) if info.displayName.contains("Exception") => true + case ClassSignature(_, parents, _, _) => + // Recursively check parent types + parents.map(_.asInstanceOf[TypeRef].symbol).exists(inheritsFromException(_)) + case TypeSignature(_, TypeRef(_, symbol1, _), TypeRef(_, symbol2, _)) => + // Check if inherits java.lang.Exception + val matcher = SymbolMatcher.exact("java/lang/Exception#") + matcher.matches(symbol1) || matcher.matches(symbol2) + case _ => false + } + case None => false + } + } + + override def fix(implicit doc: SemanticDocument): Patch = { + doc.tree.collect { + case cl @ Defn.Class.After_4_6_0(_, Type.Name(name), _, _, _) => + cl.symbol.info.get.signature match { + case ClassSignature(_, parents, _, _) => + if (!name.contains("Exception") && parents.map(_.asInstanceOf[TypeRef].symbol).exists(inheritsFromException)) + Patch.lint(IncorrectlyNamedExceptionsDiag(cl)) + else Patch.empty + case _ => Patch.empty + } + }.asPatch + } + +} diff --git a/be2-scala/linter/rules/src/main/scala/fix/LonelySealedTrait.scala b/be2-scala/linter/rules/src/main/scala/fix/LonelySealedTrait.scala new file mode 100644 index 0000000000..4d54a72053 --- /dev/null +++ b/be2-scala/linter/rules/src/main/scala/fix/LonelySealedTrait.scala @@ -0,0 +1,72 @@ +/* +rule = LonelySealedTrait + */ +package fix + +import scalafix.lint.LintSeverity + +import scala.meta._ +import scalafix.v1._ + +import scala.collection.mutable + +case class LonelySealedTraitDiag(t: Tree) extends Diagnostic { + override def message: String = "Lonely sealed trait" + + override def severity: LintSeverity = LintSeverity.Error + + override def explanation: String = "A sealed trait that is not extended is considered dead code." + + override def position: Position = t.pos +} + +class LonelySealedTrait extends SemanticRule("LonelySealedTrait") { + + private def findBaseClasses(parent: String, sealedTraitsHierarchy: mutable.Map[String, Set[String]]): Set[String] = { + sealedTraitsHierarchy.get(parent) match { + case Some(baseClasses) => + baseClasses ++ baseClasses.flatMap(findBaseClasses(_, sealedTraitsHierarchy)) + case None => + Set.empty[String] + } + } + + + override def fix(implicit doc: SemanticDocument): Patch = { + val sealedTraits = mutable.Map[String, Defn]() + val parents = mutable.Set[String]() + val sealedTraitsHierarchy = mutable.Map[String, Set[String]]() + + + // Use display name to handle the case A[B] where A is a trait and B is a type parameter + def prettyInits(inits: List[Init]): List[String] = inits.collect { case i if i.symbol.info.isDefined => i.symbol.info.get.displayName } + + def handleSealedTrait(inits: List[Init], cl: Defn, clname: String): Unit = { + if(inits.isEmpty) sealedTraitsHierarchy += (clname -> Set()) + else { + prettyInits(inits).foreach(i => sealedTraitsHierarchy += (i -> (sealedTraitsHierarchy.getOrElse(i, Set()) + clname))) + } + sealedTraits += (clname -> cl) + } + + + doc.tree.traverse { + case tr @ Defn.Trait.After_4_6_0(mods, name, _, _, Template.After_4_4_0(_, inits, _, _, _)) + if mods.exists(m => m.toString.equals("sealed")) => handleSealedTrait(inits, tr, name.value) + case cls @ Defn.Class.After_4_6_0(mods, name, _, _, Template.After_4_4_0(_, inits, _, _, _)) + if mods.exists(m => m.toString.equals("sealed")) => handleSealedTrait(inits, cls, name.value) + case cl @ Defn.Class.After_4_6_0(_, _, _, _, _) => cl.symbol.info.get.signature match { + case ClassSignature(_, pars, _, _) => pars.foreach(p => parents += p.toString()) + case _ => () // Class should have class signature + } + case Defn.Object(_, _, Template.After_4_4_0(_, inits, _, _, _)) => parents ++= prettyInits(inits) + } + + sealedTraits.collect({ + case (name, cl) if !parents.contains(name) + && parents.intersect(findBaseClasses(name,sealedTraitsHierarchy)).isEmpty => Patch.lint(LonelySealedTraitDiag(cl)) + //Either this sealed trait is directly extended or one of its sealed base classes / traits is being extended + case _ => Patch.empty + }).asPatch + } +} diff --git a/be2-scala/linter/rules/src/main/scala/fix/MapGetAndGetOrElse.scala b/be2-scala/linter/rules/src/main/scala/fix/MapGetAndGetOrElse.scala new file mode 100644 index 0000000000..01bb8a8f8d --- /dev/null +++ b/be2-scala/linter/rules/src/main/scala/fix/MapGetAndGetOrElse.scala @@ -0,0 +1,31 @@ +/* +rule = MapGetAndGetOrElse + */ +package fix + +import scalafix.lint.LintSeverity + +import scala.meta._ +import scalafix.v1._ + +case class MapGetAndGetOrElseDiag(map: Tree) extends Diagnostic { + override def message: String = "Using of Map.get().getOrElse instead of Map.getOrElse()" + + override def severity: LintSeverity = LintSeverity.Error + + override def explanation: String = "Map.get(key).getOrElse(value) can be replaced with Map.getOrElse(key, value), which is more concise." + + override def position: Position = map.pos +} + +class MapGetAndGetOrElse extends SemanticRule("MapGetAndGetOrElse") { + + override def fix(implicit doc: SemanticDocument): Patch = { + doc.tree.collect { + case Term.Apply.After_4_6_0(Term.Select(Term.Apply.After_4_6_0(Term.Select(qual, Term.Name("get")), _), Term.Name("getOrElse")), _) + if Util.matchType(qual, "scala/collection/immutable/Map", "scala/collection/mutable/Map", "scala/Predef.Map") + => Patch.lint(MapGetAndGetOrElseDiag(qual)) + case _ => Patch.empty + }.asPatch + } +} diff --git a/be2-scala/linter/rules/src/main/scala/fix/NanComparison.scala b/be2-scala/linter/rules/src/main/scala/fix/NanComparison.scala new file mode 100644 index 0000000000..36b83a72a3 --- /dev/null +++ b/be2-scala/linter/rules/src/main/scala/fix/NanComparison.scala @@ -0,0 +1,50 @@ +/* +rule = NanComparison +*/ +package fix + +import scalafix.lint.LintSeverity +import scalafix.v1._ +import scala.meta._ + +case class NanComparisonDiag(tree: Tree) extends Diagnostic { + override def message: String = "Comparison with NaN" + + override def severity: LintSeverity = LintSeverity.Error + + override def explanation: String = "NaN comparison will always fail. Use value.isNaN instead." + + override def position: Position = tree.pos +} + +class NanComparison extends SemanticRule("NanComparison") { + + override def fix(implicit doc: SemanticDocument): Patch = { + + def matcher(l: Any, r: Any, t: Tree): Patch = { + // Either we have definitions to extract or we have direct terms and matching is the same + // => matcher function avoids repetition. Type of arguments is any since findDefinition will return Any, + // but it doesn't matter since we do pattern matching + (l, r) match { + case (Term.Select(Term.Name("Double"), Term.Name("NaN")), _) | (_, Term.Select(Term.Name("Double"), Term.Name("NaN"))) => Patch.lint(NanComparisonDiag(t)) + case _ => Patch.empty + } + } + + def rule(lhs: Term, rhs: Term, t: Term): Patch = { + (lhs,rhs) match { + case (Term.Name(_), Term.Name(_)) => Util.findDefinitions(doc.tree, Set(lhs, rhs)) match { + case List((_, ld), (_, rd)) => matcher(ld, rd, t) + case _ => Patch.empty + } + // Extract definitions and match in the case both are variables + case _ => matcher(lhs, rhs, t) + } + } + + doc.tree.collect { + case t @ Term.ApplyInfix.After_4_6_0(lhs, Term.Name("==") | Term.Name("!="), _, Term.ArgClause(List(rhs), _)) => rule(lhs, rhs, t) + case t @ Term.Apply.After_4_6_0(Term.Select(lhs, Term.Name("equals")), Term.ArgClause(List(right), _)) => rule(lhs, right, t) + }.asPatch + } +} diff --git a/be2-scala/linter/rules/src/main/scala/fix/OptionGet.scala b/be2-scala/linter/rules/src/main/scala/fix/OptionGet.scala new file mode 100644 index 0000000000..d53d889e91 --- /dev/null +++ b/be2-scala/linter/rules/src/main/scala/fix/OptionGet.scala @@ -0,0 +1,28 @@ +/* +rule = OptionGet +*/ +package fix + +import scalafix.lint.LintSeverity +import scalafix.v1._ +import scala.meta._ + +case class OptionGetDiag(tree: Tree) extends Diagnostic { + override def message: String = "Use of Option.Get" + + override def severity: LintSeverity = LintSeverity.Error + + override def explanation: String = "Using Option.get defeats the purpose of using Option in the first place. Use the following instead: Option.getOrElse, Option.fold, pattern matching or don't take the value out of the container and map over it to transform it" + + override def position: Position = tree.pos +} + +class OptionGet extends SemanticRule("OptionGet") { + + override def fix(implicit doc: SemanticDocument): Patch = { + doc.tree.collect { + case t @ Term.Select(qual, Term.Name("get")) if Util.matchType(qual, "scala/Option") => Patch.lint(OptionGetDiag(t)) + case _ => Patch.empty + }.asPatch + } +} diff --git a/be2-scala/linter/rules/src/main/scala/fix/OptionSize.scala b/be2-scala/linter/rules/src/main/scala/fix/OptionSize.scala new file mode 100644 index 0000000000..6715995b06 --- /dev/null +++ b/be2-scala/linter/rules/src/main/scala/fix/OptionSize.scala @@ -0,0 +1,28 @@ +/* +rule = OptionSize +*/ +package fix + +import scalafix.lint.LintSeverity +import scalafix.v1._ +import scala.meta._ + +case class OptionSizeDiag(tree: Tree) extends Diagnostic { + override def message: String = "Prefer Option.isDefined instead of Option.size" + + override def severity: LintSeverity = LintSeverity.Error + + override def explanation: String = "Prefer to use Option.isDefined, Option.isEmpty or Option.nonEmpty instead of Option.size." + + override def position: Position = tree.pos +} + +class OptionSize extends SemanticRule("OptionSize") { + + override def fix(implicit doc: SemanticDocument): Patch = { + doc.tree.collect { + case t @ Term.Select(qual, Term.Name("size")) if Util.matchType(qual, "scala/Option") => Patch.lint(OptionSizeDiag(t)) + case _ => Patch.empty + }.asPatch + } +} diff --git a/be2-scala/linter/rules/src/main/scala/fix/StripMarginOnRegex.scala b/be2-scala/linter/rules/src/main/scala/fix/StripMarginOnRegex.scala new file mode 100644 index 0000000000..694ddab256 --- /dev/null +++ b/be2-scala/linter/rules/src/main/scala/fix/StripMarginOnRegex.scala @@ -0,0 +1,28 @@ +/* +rule = StripMarginOnRegex +*/ +package fix + +import scalafix.lint.LintSeverity +import scalafix.v1._ +import scala.meta._ + +case class StripMarginOnRegexDiag(tree: Tree) extends Diagnostic { + override def message: String = "Strip margin on regex" + + override def severity: LintSeverity = LintSeverity.Error + + override def explanation: String = "Strip margin will strip | from regex - possible corrupted regex." + + override def position: Position = tree.pos +} + +class StripMarginOnRegex extends SemanticRule("StripMarginOnRegex") { + + override def fix(implicit doc: SemanticDocument): Patch = { + doc.tree.collect { + case t @ Term.Select(Term.Select(Lit.String(string), Term.Name("stripMargin")), Term.Name("r")) if string.contains('|') => Patch.lint(StripMarginOnRegexDiag(t)) + case _ => Patch.empty + }.asPatch + } +} diff --git a/be2-scala/linter/rules/src/main/scala/fix/TryGet.scala b/be2-scala/linter/rules/src/main/scala/fix/TryGet.scala new file mode 100644 index 0000000000..8e07f1fa2f --- /dev/null +++ b/be2-scala/linter/rules/src/main/scala/fix/TryGet.scala @@ -0,0 +1,28 @@ +/* +rule = TryGet +*/ +package fix + +import scalafix.lint.LintSeverity +import scalafix.v1._ +import scala.meta._ + +case class TryGetDiag(tree: Tree) extends Diagnostic { + override def message: String = "Use of Try.Get" + + override def severity: LintSeverity = LintSeverity.Error + + override def explanation: String = "Using Try.get is unsafe because it will throw the underlying exception in case of a Failure. Use the following instead: Try.getOrElse, Try.fold, pattern matching or don't take the value out of the container and map over it to transform it." + + override def position: Position = tree.pos +} + +class TryGet extends SemanticRule("TryGet") { + + override def fix(implicit doc: SemanticDocument): Patch = { + doc.tree.collect { + case t @ Term.Select(qual, Term.Name("get")) if Util.matchType(qual, "scala/util/Try") => Patch.lint(TryGetDiag(t)) + case _ => Patch.empty + }.asPatch + } +} diff --git a/be2-scala/linter/rules/src/main/scala/fix/Util.scala b/be2-scala/linter/rules/src/main/scala/fix/Util.scala new file mode 100644 index 0000000000..bd55e2f27c --- /dev/null +++ b/be2-scala/linter/rules/src/main/scala/fix/Util.scala @@ -0,0 +1,49 @@ +package fix + +import scalafix.v1._ + +import scala.meta._ + +object Util { + def getType(term: Term)(implicit doc: SemanticDocument): Symbol = { + term.symbol.info match { + case Some(symInfo) => symInfo.signature match { + case ValueSignature(TypeRef(_, symbol, _)) => symbol + case _ => Symbol.None + } + case _ => Symbol.None + } + } + + def matchType(term: Term, symbols: String*)(implicit doc: SemanticDocument): Boolean = { + val symbolMatcher = SymbolMatcher.normalized(symbols: _*) + symbolMatcher.matches(term.symbol) || symbolMatcher.matches(getType(term)) + // Checks the term symbol matches that of the symbol (i.e. when we use the type directly), + // or if the type of the variable matches. + } + + def findDefinition(tree: Tree, name: Term): Any = { + tree.collect { + case Defn.Val(_, List(Pat.Var(varName)), _, value) + if varName.value.equals(name.toString) => value + case Defn.Var.After_4_7_2(_, List(Pat.Var(varName)), _, value) + if varName == name => value + }.headOption.orNull + } + + def findDefinitions(tree: Tree, nameSet: Set[Term]): List[(Term, Any)] = { + tree.collect { + case Defn.Val(_, List(Pat.Var(varName)), _, value) if nameSet.exists(_.toString().equals(varName.value)) => nameSet.find(_.toString().equals(varName.value)).get -> value + case Defn.Var.After_4_7_2(_, List(Pat.Var(varName)), _, value) if nameSet.exists(_.toString().equals(varName.value)) => nameSet.find(_.toString().equals(varName.value)).get -> value + } + } + + // Finds multiple definitions in one tree traversal + def findDefinitionsMap(tree: Tree, nameSet: Set[Term]): Map[Term, Any] = { + findDefinitions(tree, nameSet).toMap + } + + def findDefinitionsOrdered(tree: Tree, nameSet: List[Term]): List[Any] = { + findDefinitions(tree, nameSet.toSet).sortBy { case (term, _) => nameSet.indexOf(term) }.map { case (_, value) => value } + } +}