Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix varargs Ops generation #236

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 38 additions & 35 deletions core/src/main/scala/simulacrum/typeclass.scala
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
package simulacrum

import scala.annotation.{ compileTimeOnly, StaticAnnotation }
import scala.annotation.{compileTimeOnly, StaticAnnotation}
import scala.language.experimental.macros
import scala.reflect.macros.whitebox.Context
import scala.reflect.api.Trees
import scala.reflect.api.Annotations
import scala.reflect.macros.whitebox

/**
* Annotation that may be applied to methods on a type that is annotated with `@typeclass`.
Expand Down Expand Up @@ -52,7 +50,7 @@ class typeclass(excludeParents: List[String] = Nil, generateAllOps: Boolean = tr
def macroTransform(annottees: Any*): Any = macro TypeClassMacros.generateTypeClass
}

class TypeClassMacros(val c: Context) {
class TypeClassMacros(val c: whitebox.Context) {
import c.universe._

def generateTypeClass(annottees: c.Expr[Any]*): c.Expr[Any] = {
Expand All @@ -68,9 +66,9 @@ class TypeClassMacros(val c: Context) {
TypeDef(fixedMods.asInstanceOf[c.universe.Modifiers], tparam.name, tparam.tparams, tparam.rhs)
}

def trace(s: => String) = {
def trace(s: => String): Unit = {
// Macro paradise seems to always output info statements, even without -verbose
if (sys.props.get("simulacrum.trace").isDefined) c.info(c.enclosingPosition, s, false)
if (sys.props.get("simulacrum.trace").isDefined) c.info(c.enclosingPosition, s, force = false)
}

class RewriteTypeName(from: TypeName, to: TypeName) extends Transformer {
Expand All @@ -81,6 +79,13 @@ class TypeClassMacros(val c: Context) {
}
}

def rewriteTypeName(tree: Tree, old: TypeName, name: TypeName): Tree =
new RewriteTypeName(from = old, to = name).transform(tree)

def rewriteTypeNameOp(newNames: Iterator[TypeName])(tree: Tree, lta: TypeDef): Tree = {
rewriteTypeName(tree, lta.name, newNames.next)
}

class FoldTransformer(transformers: List[Transformer]) extends Transformer {
override def transform(t: Tree): Tree = super.transform(transformers.foldLeft(t)((prev, transformer) => transformer.transform(prev)))
}
Expand All @@ -100,13 +105,14 @@ class TypeClassMacros(val c: Context) {
}

def determineOpsMethodName(sourceMethod: DefDef): List[TermName] = {
val suppress = sourceMethod.mods.annotations.filter { ann =>
val suppress = sourceMethod.mods.annotations.exists { ann =>
val typed = c.typecheck(ann)
typed.tpe.typeSymbol.fullName match {
case "simulacrum.noop" => true
case _ => false
}
}.nonEmpty
}

if (suppress) Nil
else {
def genAlias(alias: String, rest: List[Tree]) = {
Expand All @@ -120,7 +126,7 @@ class TypeClassMacros(val c: Context) {
case q"alias = ${Literal(Constant(alias: Boolean))}" :: _ =>
if (alias) List(sourceMethod.name.toTermName, aliasTermName)
else List(aliasTermName)
case other =>
case _ =>
List(aliasTermName)
}
}
Expand Down Expand Up @@ -168,8 +174,14 @@ class TypeClassMacros(val c: Context) {
else firstParamList.tail :: method.vparamss.tail
}
paramNamess: List[List[Tree]] = {
val original = method.vparamss map { _ map { p => Ident(p.name) } }
original.updated(0, original(0).updated(0, q"self"))
val original = method.vparamss.map { vparams =>
vparams.map {
case ValDef(_, TermName(name), AppliedTypeTree(Select(_, TypeName("<repeated>")), _), _) =>
Typed(Ident(TermName(name)), Ident(typeNames.WILDCARD_STAR))
case p => Ident(p.name)
}
}
original.updated(0, original.head.updated(0, q"self"))
}
rhs = paramNamess.foldLeft(Select(Ident(tcInstanceName), method.name): Tree) { (tree, paramNames) =>
Apply(tree, paramNames)
Expand Down Expand Up @@ -201,11 +213,11 @@ class TypeClassMacros(val c: Context) {
args.zipWithIndex.map {
case (arg, idx) =>
val simpleArgOpt = extract(arg)
(arg, simpleArgOpt, liftedTypeArgs(idx), simpleArgOpt.map(arg equalsStructure Ident(_)).getOrElse(false))
(arg, simpleArgOpt, liftedTypeArgs(idx), simpleArgOpt.exists(arg equalsStructure Ident(_)))
}
}

val skipMethod = !simpleArgs.foldLeft(true)(_ && _._2.isDefined)
val skipMethod = !simpleArgs.forall(_._2.isDefined)

if(skipMethod) List.empty else {
//rewrites all occurrences of any of the args which are defined on the method to the lifted arg
Expand Down Expand Up @@ -234,7 +246,7 @@ class TypeClassMacros(val c: Context) {
val paramNamess: List[List[Tree]] = {
val original = method.vparamss map { _ map { p => Ident(p.name) } }
val replacement = if (equalityEvidences.isEmpty) q"self" else q"self.asInstanceOf[${tparamName.toTypeName}[..$args]]"
original.updated(0, original(0).updated(0, replacement))
original.updated(0, original.head.updated(0, replacement))
}

val mtparamss = if(equalityEvidences.isEmpty) method.tparams.map(t => tq"""${t.name}""").map(rewriteSimpleArgs.transform) else Nil
Expand Down Expand Up @@ -275,7 +287,7 @@ class TypeClassMacros(val c: Context) {
if (abstractTypeMembers.isEmpty) {
tq"${typeClass.name}[${tparam.name}]"
} else {
val refinements = abstractTypeMembers.map { case TypeDef(mods, name, tparams, rhs) =>
val refinements = abstractTypeMembers.map { case TypeDef(_, name, tparams, _) =>
val (namedParams, names) = tparams.map { case TypeDef(pmods, _, ptparams, prhs) =>
val newName = TypeName(c.freshName)
(TypeDef(pmods, newName, ptparams, prhs), newName)
Expand Down Expand Up @@ -319,16 +331,11 @@ class TypeClassMacros(val c: Context) {
def generateAllOps(typeClass: ClassDef, tcInstanceName: TermName, tparam: TypeDef, liftedTypeArgs: List[TypeDef]): ClassDef = {
val tparams = List(tparam) ++ liftedTypeArgs
val tparamNames = tparams.map { _.name }
val tcargs = typeClass.mods.annotations.flatMap { ann =>
val typed = c.typecheck(ann)
if (typed.tpe.typeSymbol.fullName == "simulacrum.typeclass") {
val q"new ${_}(..${args})" = typed
List(args)
} else Nil
}

val typeClassParents: List[TypeName] = typeClass.impl.parents.collect {
case tq"${Ident(parentTypeClassTypeName)}[${_}]" => parentTypeClassTypeName.toTypeName
}

val allOpsParents = typeClassParents collect {
case parent if !(typeClassArguments.parentsToExclude contains parent) =>
tq"${parent.toTermName}.AllOps[..$tparamNames]"
Expand Down Expand Up @@ -375,7 +382,7 @@ class TypeClassMacros(val c: Context) {
(i + 1) -> (rewriteWildcard.transformTypeDefs(List(TypeDef(fixedMods, tname, tpps, rhs))).head :: ts)
}._2.reverse
}
}
}

val tcInstanceName = TermName("typeClassInstance")

Expand Down Expand Up @@ -406,10 +413,7 @@ class TypeClassMacros(val c: Context) {
q"trait $toOpsTraitName { $method }"
}

val nonInheritedOpsConversion = {
val method = generateOpsImplicitConversion(opsTrait.name, TermName(s"to${typeClass.name}Ops"))
q"object nonInheritedOps extends ${toOpsTrait.name}"
}
val nonInheritedOpsConversion = q"object nonInheritedOps extends ${toOpsTrait.name}"

val allOpsConversion = {
val method = generateOpsImplicitConversion(TypeName("AllOps"), TermName(s"toAll${typeClass.name}Ops"))
Expand All @@ -433,13 +437,12 @@ class TypeClassMacros(val c: Context) {
"""

// Rewrite liftedTypeArg.name to something easier to read
val potentialNames = ('A' to 'Z').map(ch => TypeName(ch.toString)).filter(nme => !opsReservedTParamNames.contains(nme))
val potentialNames = ('A' to 'Z')
.toIterator
.map(ch => TypeName(ch.toString))
.filter(nme => !opsReservedTParamNames.contains(nme))

liftedTypeArgs.foldLeft((companion: Tree) -> potentialNames) {
case ((prev, namesLeft), lta) =>
val newName = namesLeft.head
new RewriteTypeName(from = lta.name, to = newName).transform(prev) -> namesLeft.tail
}._1
liftedTypeArgs.foldLeft(companion: Tree) { rewriteTypeNameOp(potentialNames) }
}

def modify(typeClass: ClassDef, companion: Option[ModuleDef]) = {
Expand Down Expand Up @@ -493,7 +496,7 @@ class TypeClassMacros(val c: Context) {
annottees.map(_.tree) match {
case (typeClass: ClassDef) :: Nil => modify(typeClass, None)
case (typeClass: ClassDef) :: (companion: ModuleDef) :: Nil => modify(typeClass, Some(companion))
case other :: Nil =>
case _ :: Nil =>
c.abort(c.enclosingPosition, "@typeclass can only be applied to traits or abstract classes that take 1 type parameter which is either a proper type or a type constructor")
}
}
Expand Down
24 changes: 24 additions & 0 deletions core/src/test/scala/simulacrum/typeclass.scala
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,30 @@ class TypeClassTest extends AnyWordSpec with Matchers {
1 ~ 2 shouldBe 3
}

"supports varargs in adapted methods" in {

@typeclass trait Vargs[T] {
@op("/:", true) def fold[T2](x: T, y: T2*)(f: (T, T2) => T): T
}

case class Sum(total: Int = 0) {
def add(operand: Int): Sum = Sum(total + operand)
}

implicit val sumVargs: Vargs[Sum] = new Vargs[Sum] {
def fold[T2](x: Sum, ys: T2*)(f: (Sum, T2) => Sum): Sum =
ys.foldLeft(x) { (ds, t2) => f(ds, t2) }
}

import Vargs.ops._

val sum = Sum()
val ops = List(1, 2, 3, 4)

sum.fold(ops: _*) { _ add _ } shouldBe Sum(10)
sum./: ((ops ::: ops): _*) { _ add _ } shouldBe Sum(20)
}

"supports suppression of adapter methods" in {
@typeclass trait Sg[A] {
@noop def append(x: A, y: A): A
Expand Down
Loading