diff --git a/core/src/main/scala-2.x/src/main/scala/cats/arrow/FunctionKMacros.scala b/core/src/main/scala-2.x/src/main/scala/cats/arrow/FunctionKMacros.scala index 6e79d704e24..5e301ed3e57 100644 --- a/core/src/main/scala-2.x/src/main/scala/cats/arrow/FunctionKMacros.scala +++ b/core/src/main/scala-2.x/src/main/scala/cats/arrow/FunctionKMacros.scala @@ -2,7 +2,7 @@ package cats package arrow import scala.language.experimental.macros -import scala.reflect.macros.blackbox.Context +import scala.reflect.macros.blackbox private[arrow] class FunctionKMacroMethods { @@ -31,72 +31,41 @@ private[arrow] class FunctionKMacroMethods { private[arrow] object FunctionKMacros { - def lift[F[_], G[_]](c: Context)( + def lift[F[_], G[_]](c: blackbox.Context)( f: c.Expr[(F[α] => G[α]) forSome { type α }] - )(implicit - evF: c.WeakTypeTag[F[_]], - evG: c.WeakTypeTag[G[_]] - ): c.Expr[FunctionK[F, G]] = + )(implicit evF: c.WeakTypeTag[F[Any]], evG: c.WeakTypeTag[G[Any]]): c.Expr[FunctionK[F, G]] = c.Expr[FunctionK[F, G]](new Lifter[c.type](c).lift[F, G](f.tree)) - // ^^note: extra space after c.type to appease scalastyle - private[this] class Lifter[C <: Context](val c: C) { + private class Lifter[C <: blackbox.Context](val c: C) { import c.universe._ - def lift[F[_], G[_]](tree: Tree)(implicit - evF: c.WeakTypeTag[F[_]], - evG: c.WeakTypeTag[G[_]] - ): Tree = - unblock(tree) match { - case q"($param) => $trans[..$typeArgs](${arg: Ident})" if param.name == arg.name => - typeArgs - .collect { case tt: TypeTree => tt } - .find(tt => Option(tt.original).isDefined) - .foreach { param => - c.abort(param.pos, - s"type parameter $param must not be supplied when lifting function $trans to FunctionK" + def lift[F[_], G[_]](tree: Tree)(implicit evF: c.WeakTypeTag[F[Any]], evG: c.WeakTypeTag[G[Any]]): Tree = { + def liftFunction(function: Tree): Tree = + function match { + case q"($param) => $trans[..$typeArgs]($arg)" if param.symbol == arg.symbol => + for (typeArg @ TypeTree() <- typeArgs) if (typeArg.original != null) { + c.abort( + typeArg.pos, + s"type parameter $typeArg must not be supplied when lifting function $trans to FunctionK" ) } - val F = punchHole(evF.tpe) - val G = punchHole(evG.tpe) + val F = typeConstructorOf[F[Any]] + val G = typeConstructorOf[G[Any]] + q"${reify(FunctionK)}.liftFunction[$F, $G]($trans(_))" - q""" - new _root_.cats.arrow.FunctionK[$F, $G] { - def apply[A](fa: $F[A]): $G[A] = $trans(fa) + case other => + c.abort(other.pos, s"Unexpected tree $other when lifting to FunctionK") } - """ - case other => - c.abort(other.pos, s"Unexpected tree $other when lifting to FunctionK") - } - private[this] def unblock(tree: Tree): Tree = tree match { - case Block(Nil, expr) => expr - case _ => tree - } - - private[this] def punchHole(tpe: Type): Tree = - tpe match { - case PolyType(undet :: Nil, underlying: TypeRef) => - val α = TypeName("α") - def rebind(typeRef: TypeRef): Tree = - if (typeRef.sym == undet) tq"$α" - else { - val args = typeRef.args.map { - case ref: TypeRef => rebind(ref) - case arg => tq"$arg" - } - tq"${typeRef.sym}[..$args]" - } - val rebound = rebind(underlying) - tq"""({type λ[$α] = $rebound})#λ""" - case TypeRef(pre, sym, Nil) => - tq"$sym" - case _ => - c.abort(c.enclosingPosition, s"Unexpected type $tpe when lifting to FunctionK") + case Block(Nil, expr) => liftFunction(expr) + case Block(stats, expr) => Block(stats, liftFunction(expr)) + case other => liftFunction(other) } + } + private def typeConstructorOf[A: WeakTypeTag]: Type = + weakTypeOf[A].typeConstructor.etaExpand } - } diff --git a/core/src/main/scala/cats/arrow/FunctionK.scala b/core/src/main/scala/cats/arrow/FunctionK.scala index f2f6457ae13..c1a14893a76 100644 --- a/core/src/main/scala/cats/arrow/FunctionK.scala +++ b/core/src/main/scala/cats/arrow/FunctionK.scala @@ -62,10 +62,26 @@ trait FunctionK[F[_], G[_]] extends Serializable { self => } object FunctionK extends FunctionKMacroMethods { + protected type τ[F[_], G[_]] /** * The identity transformation of `F` to `F` */ def id[F[_]]: FunctionK[F, F] = new FunctionK[F, F] { def apply[A](fa: F[A]): F[A] = fa } + /** + * Lifts function `f` of `F[A] => G[A]` into a `FunctionK[F, G]`. + * + * {{{ + * def headOption[A](list: List[A]): Option[A] = list.headOption + * val lifted = FunctionK.liftFunction[List, Option](headOption) + * }}} + * + * Note: The weird `τ[F, G]` parameter is there to compensate for + * the lack of polymorphic function types in Scala 2. + */ + def liftFunction[F[_], G[_]](f: F[τ[F, G]] => G[τ[F, G]]): FunctionK[F, G] = + new FunctionK[F, G] { + def apply[A](fa: F[A]): G[A] = f.asInstanceOf[F[A] => G[A]](fa) + } } diff --git a/tests/src/test/scala/cats/tests/FunctionKSuite.scala b/tests/src/test/scala/cats/tests/FunctionKSuite.scala index db065f851bd..6512c58101e 100644 --- a/tests/src/test/scala/cats/tests/FunctionKSuite.scala +++ b/tests/src/test/scala/cats/tests/FunctionKSuite.scala @@ -1,12 +1,12 @@ package cats.tests -import cats.Id import cats.arrow.FunctionK -import cats.data.EitherK -import cats.data.NonEmptyList +import cats.data.{EitherK, NonEmptyList} import cats.laws.discipline.arbitrary._ +import cats.{Applicative, Id} class FunctionKSuite extends CatsSuite { + type OptionOfNel[+A] = Option[NonEmptyList[A]] val listToOption = new FunctionK[List, Option] { def apply[A](a: List[A]): Option[A] = a.headOption } val listToVector = new FunctionK[List, Vector] { def apply[A](a: List[A]): Vector[A] = a.toVector } @@ -92,12 +92,27 @@ class FunctionKSuite extends CatsSuite { } test("lift compound unary") { - val fNelFromList = FunctionK.lift[List, λ[α => Option[NonEmptyList[α]]]](NonEmptyList.fromList _) + val fNelFromList = FunctionK.lift[List, OptionOfNel](NonEmptyList.fromList) forAll { (a: List[String]) => fNelFromList(a) should ===(NonEmptyList.fromList(a)) } } + test("lift eta-expanded function") { + val fSomeNel = FunctionK.lift[NonEmptyList, OptionOfNel](Applicative[Option].pure) + forAll { (a: NonEmptyList[Int]) => + fSomeNel(a) should ===(Some(a)) + } + } + + test("lift a function directly") { + def headOption[A](list: List[A]): Option[A] = list.headOption + val fHeadOption = FunctionK.liftFunction[List, Option](headOption) + forAll { (a: List[Int]) => + fHeadOption(a) should ===(a.headOption) + } + } + { // lifting concrete types should fail to compile def sample[A](option: Option[A]): List[A] = option.toList assertTypeError("FunctionK.lift(sample[String])")