Skip to content

Commit

Permalink
Simplify FunctionK.lift macro
Browse files Browse the repository at this point in the history
Add `FunctionK.liftFunction` that uses an abstract type member
to emulate a polymorphic function type.

Because `liftFunction` has worse type inference,
keep the macro based `lift` that now delegates to `liftFunction`,
instead of creating a new `FunctionK` instance every time.
Also accept eta-expanded functions in the macro implementation.
  • Loading branch information
joroKr21 committed Jul 21, 2020
1 parent d664b05 commit bc1b7b3
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 58 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package cats
package arrow

import scala.language.experimental.macros
import scala.reflect.macros.blackbox.Context
import scala.reflect.macros.blackbox

private[arrow] class FunctionKMacroMethods {

Expand Down Expand Up @@ -31,72 +31,41 @@ private[arrow] class FunctionKMacroMethods {

private[arrow] object FunctionKMacros {

def lift[F[_], G[_]](c: Context)(
def lift[F[_], G[_]](c: blackbox.Context)(
f: c.Expr[(F[α] => G[α]) forSome { type α }]
)(implicit
evF: c.WeakTypeTag[F[_]],
evG: c.WeakTypeTag[G[_]]
): c.Expr[FunctionK[F, G]] =
)(implicit evF: c.WeakTypeTag[F[Any]], evG: c.WeakTypeTag[G[Any]]): c.Expr[FunctionK[F, G]] =
c.Expr[FunctionK[F, G]](new Lifter[c.type](c).lift[F, G](f.tree))
// ^^note: extra space after c.type to appease scalastyle

private[this] class Lifter[C <: Context](val c: C) {
private class Lifter[C <: blackbox.Context](val c: C) {
import c.universe._

def lift[F[_], G[_]](tree: Tree)(implicit
evF: c.WeakTypeTag[F[_]],
evG: c.WeakTypeTag[G[_]]
): Tree =
unblock(tree) match {
case q"($param) => $trans[..$typeArgs](${arg: Ident})" if param.name == arg.name =>
typeArgs
.collect { case tt: TypeTree => tt }
.find(tt => Option(tt.original).isDefined)
.foreach { param =>
c.abort(param.pos,
s"type parameter $param must not be supplied when lifting function $trans to FunctionK"
def lift[F[_], G[_]](tree: Tree)(implicit evF: c.WeakTypeTag[F[Any]], evG: c.WeakTypeTag[G[Any]]): Tree = {
def liftFunction(function: Tree): Tree =
function match {
case q"($param) => $trans[..$typeArgs]($arg)" if param.symbol == arg.symbol =>
for (typeArg @ TypeTree() <- typeArgs) if (typeArg.original != null) {
c.abort(
typeArg.pos,
s"type parameter $typeArg must not be supplied when lifting function $trans to FunctionK"
)
}

val F = punchHole(evF.tpe)
val G = punchHole(evG.tpe)
val F = typeConstructorOf[F[Any]]
val G = typeConstructorOf[G[Any]]
q"${reify(FunctionK)}.liftFunction[$F, $G]($trans(_))"

q"""
new _root_.cats.arrow.FunctionK[$F, $G] {
def apply[A](fa: $F[A]): $G[A] = $trans(fa)
case other =>
c.abort(other.pos, s"Unexpected tree $other when lifting to FunctionK")
}
"""
case other =>
c.abort(other.pos, s"Unexpected tree $other when lifting to FunctionK")
}

private[this] def unblock(tree: Tree): Tree =
tree match {
case Block(Nil, expr) => expr
case _ => tree
}

private[this] def punchHole(tpe: Type): Tree =
tpe match {
case PolyType(undet :: Nil, underlying: TypeRef) =>
val α = TypeName("α")
def rebind(typeRef: TypeRef): Tree =
if (typeRef.sym == undet) tq""
else {
val args = typeRef.args.map {
case ref: TypeRef => rebind(ref)
case arg => tq"$arg"
}
tq"${typeRef.sym}[..$args]"
}
val rebound = rebind(underlying)
tq"""({type λ[] = $rebound})#λ"""
case TypeRef(pre, sym, Nil) =>
tq"$sym"
case _ =>
c.abort(c.enclosingPosition, s"Unexpected type $tpe when lifting to FunctionK")
case Block(Nil, expr) => liftFunction(expr)
case Block(stats, expr) => Block(stats, liftFunction(expr))
case other => liftFunction(other)
}
}

private def typeConstructorOf[A: WeakTypeTag]: Type =
weakTypeOf[A].typeConstructor.etaExpand
}

}
16 changes: 16 additions & 0 deletions core/src/main/scala/cats/arrow/FunctionK.scala
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,26 @@ trait FunctionK[F[_], G[_]] extends Serializable { self =>
}

object FunctionK extends FunctionKMacroMethods {
protected type τ[F[_], G[_]]

/**
* The identity transformation of `F` to `F`
*/
def id[F[_]]: FunctionK[F, F] = new FunctionK[F, F] { def apply[A](fa: F[A]): F[A] = fa }

/**
* Lifts function `f` of `F[A] => G[A]` into a `FunctionK[F, G]`.
*
* {{{
* def headOption[A](list: List[A]): Option[A] = list.headOption
* val lifted = FunctionK.liftFunction[List, Option](headOption)
* }}}
*
* Note: The weird `τ[F, G]` parameter is there to compensate for
* the lack of polymorphic function types in Scala 2.
*/
def liftFunction[F[_], G[_]](f: F[τ[F, G]] => G[τ[F, G]]): FunctionK[F, G] =
new FunctionK[F, G] {
def apply[A](fa: F[A]): G[A] = f.asInstanceOf[F[A] => G[A]](fa)
}
}
23 changes: 19 additions & 4 deletions tests/src/test/scala/cats/tests/FunctionKSuite.scala
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
package cats.tests

import cats.Id
import cats.arrow.FunctionK
import cats.data.EitherK
import cats.data.NonEmptyList
import cats.data.{EitherK, NonEmptyList}
import cats.laws.discipline.arbitrary._
import cats.{Applicative, Id}

class FunctionKSuite extends CatsSuite {
type OptionOfNel[+A] = Option[NonEmptyList[A]]

val listToOption = new FunctionK[List, Option] { def apply[A](a: List[A]): Option[A] = a.headOption }
val listToVector = new FunctionK[List, Vector] { def apply[A](a: List[A]): Vector[A] = a.toVector }
Expand Down Expand Up @@ -92,12 +92,27 @@ class FunctionKSuite extends CatsSuite {
}

test("lift compound unary") {
val fNelFromList = FunctionK.lift[List, λ[α => Option[NonEmptyList[α]]]](NonEmptyList.fromList _)
val fNelFromList = FunctionK.lift[List, OptionOfNel](NonEmptyList.fromList)
forAll { (a: List[String]) =>
fNelFromList(a) should ===(NonEmptyList.fromList(a))
}
}

test("lift eta-expanded function") {
val fSomeNel = FunctionK.lift[NonEmptyList, OptionOfNel](Applicative[Option].pure)
forAll { (a: NonEmptyList[Int]) =>
fSomeNel(a) should ===(Some(a))
}
}

test("lift a function directly") {
def headOption[A](list: List[A]): Option[A] = list.headOption
val fHeadOption = FunctionK.liftFunction[List, Option](headOption)
forAll { (a: List[Int]) =>
fHeadOption(a) should ===(a.headOption)
}
}

{ // lifting concrete types should fail to compile
def sample[A](option: Option[A]): List[A] = option.toList
assertTypeError("FunctionK.lift(sample[String])")
Expand Down

0 comments on commit bc1b7b3

Please sign in to comment.