Skip to content

Commit

Permalink
Simplify FunctionK.lift macro (#3402)
Browse files Browse the repository at this point in the history
Co-authored-by: Lars Hupel <lars.hupel@mytum.de>
  • Loading branch information
joroKr21 and larsrh authored Dec 6, 2020
1 parent 9f1191d commit 8e8be51
Show file tree
Hide file tree
Showing 4 changed files with 163 additions and 122 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@ package cats
package arrow

import scala.language.experimental.macros
import scala.reflect.macros.blackbox.Context
import scala.reflect.macros.blackbox

private[arrow] class FunctionKMacroMethods {
protected type τ[F[_], G[_]]

/**
* Lifts function `f` of `F[A] => G[A]` into a `FunctionK[F, G]`.
Expand All @@ -27,76 +28,61 @@ private[arrow] class FunctionKMacroMethods {
*/
def lift[F[_], G[_]](f: (F[α] => G[α]) forSome { type α }): FunctionK[F, G] =
macro FunctionKMacros.lift[F, G]

/**
* Lifts function `f` of `F[A] => G[A]` into a `FunctionK[F, G]`.
*
* {{{
* def headOption[A](list: List[A]): Option[A] = list.headOption
* val lifted = FunctionK.liftFunction[List, Option](headOption)
* }}}
*
* Note: The weird `τ[F, G]` parameter is there to compensate for
* the lack of polymorphic function types in Scala 2.
*/
def liftFunction[F[_], G[_]](f: F[τ[F, G]] => G[τ[F, G]]): FunctionK[F, G] =
new FunctionK[F, G] {
def apply[A](fa: F[A]): G[A] = f.asInstanceOf[F[A] => G[A]](fa)
}
}

private[arrow] object FunctionKMacros {

def lift[F[_], G[_]](c: Context)(
def lift[F[_], G[_]](c: blackbox.Context)(
f: c.Expr[(F[α] => G[α]) forSome { type α }]
)(implicit
evF: c.WeakTypeTag[F[_]],
evG: c.WeakTypeTag[G[_]]
): c.Expr[FunctionK[F, G]] =
)(implicit evF: c.WeakTypeTag[F[Any]], evG: c.WeakTypeTag[G[Any]]): c.Expr[FunctionK[F, G]] =
c.Expr[FunctionK[F, G]](new Lifter[c.type](c).lift[F, G](f.tree))
// ^^note: extra space after c.type to appease scalastyle

private[this] class Lifter[C <: Context](val c: C) {
private class Lifter[C <: blackbox.Context](val c: C) {
import c.universe._

def lift[F[_], G[_]](tree: Tree)(implicit
evF: c.WeakTypeTag[F[_]],
evG: c.WeakTypeTag[G[_]]
): Tree =
unblock(tree) match {
case q"($param) => $trans[..$typeArgs](${arg: Ident})" if param.name == arg.name =>
typeArgs
.collect { case tt: TypeTree => tt }
.find(tt => Option(tt.original).isDefined)
.foreach { param =>
c.abort(param.pos,
s"type parameter $param must not be supplied when lifting function $trans to FunctionK"
def lift[F[_], G[_]](tree: Tree)(implicit evF: c.WeakTypeTag[F[Any]], evG: c.WeakTypeTag[G[Any]]): Tree = {
def liftFunction(function: Tree): Tree =
function match {
case q"($param) => $trans[..$typeArgs]($arg)" if param.symbol == arg.symbol =>
for (typeArg @ TypeTree() <- typeArgs) if (typeArg.original != null) {
c.abort(
typeArg.pos,
s"type parameter $typeArg must not be supplied when lifting function $trans to FunctionK"
)
}

val F = punchHole(evF.tpe)
val G = punchHole(evG.tpe)
val F = typeConstructorOf[F[Any]]
val G = typeConstructorOf[G[Any]]
q"${reify(FunctionK)}.liftFunction[$F, $G]($trans(_))"

q"""
new _root_.cats.arrow.FunctionK[$F, $G] {
def apply[A](fa: $F[A]): $G[A] = $trans(fa)
case other =>
c.abort(other.pos, s"Unexpected tree $other when lifting to FunctionK")
}
"""
case other =>
c.abort(other.pos, s"Unexpected tree $other when lifting to FunctionK")
}

private[this] def unblock(tree: Tree): Tree =
tree match {
case Block(Nil, expr) => expr
case _ => tree
}

private[this] def punchHole(tpe: Type): Tree =
tpe match {
case PolyType(undet :: Nil, underlying: TypeRef) =>
val α = TypeName("α")
def rebind(typeRef: TypeRef): Tree =
if (typeRef.sym == undet) tq""
else {
val args = typeRef.args.map {
case ref: TypeRef => rebind(ref)
case arg => tq"$arg"
}
tq"${typeRef.sym}[..$args]"
}
val rebound = rebind(underlying)
tq"""({type λ[] = $rebound})#λ"""
case TypeRef(pre, sym, Nil) =>
tq"$sym"
case _ =>
c.abort(c.enclosingPosition, s"Unexpected type $tpe when lifting to FunctionK")
case Block(Nil, expr) => liftFunction(expr)
case Block(stats, expr) => Block(stats, liftFunction(expr))
case other => liftFunction(other)
}
}

private def typeConstructorOf[A: WeakTypeTag]: Type =
weakTypeOf[A].typeConstructor.etaExpand
}

}
1 change: 0 additions & 1 deletion core/src/main/scala/cats/arrow/FunctionK.scala
Original file line number Diff line number Diff line change
Expand Up @@ -77,5 +77,4 @@ object FunctionK extends FunctionKMacroMethods {
* The identity transformation of `F` to `F`
*/
def id[F[_]]: FunctionK[F, F] = new FunctionK[F, F] { def apply[A](fa: F[A]): F[A] = fa }

}
123 changes: 123 additions & 0 deletions tests/src/test/scala-2.x/cats/tests/FunctionKSuite.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
package cats.tests

import cats.arrow.FunctionK
import cats.data.{EitherK, NonEmptyList}
import cats.laws.discipline.arbitrary._
import cats.syntax.eq._
import cats.{Applicative, Id}
import org.scalacheck.Prop._

class FunctionKSuite extends CatsSuite {
type OptionOfNel[+A] = Option[NonEmptyList[A]]

val listToOption = new FunctionK[List, Option] { def apply[A](a: List[A]): Option[A] = a.headOption }
val listToVector = new FunctionK[List, Vector] { def apply[A](a: List[A]): Vector[A] = a.toVector }
val optionToList = new FunctionK[Option, List] { def apply[A](a: Option[A]): List[A] = a.toList }

sealed trait Test1Algebra[A] {
def v: A
}

case class Test1[A](v: A) extends Test1Algebra[A]

sealed trait Test2Algebra[A] {
def v: A
}

case class Test2[A](v: A) extends Test2Algebra[A]

val Test1FK = new FunctionK[Test1Algebra, Id] { def apply[A](a: Test1Algebra[A]): A = a.v }
val Test2FK = new FunctionK[Test2Algebra, Id] { def apply[A](a: Test2Algebra[A]): A = a.v }

test("compose") {
forAll { (list: List[Int]) =>
val listToList = optionToList.compose(listToOption)
assert(listToList(list) === list.take(1))
}
}

test("andThen") {
forAll { (list: List[Int]) =>
val listToList = listToOption.andThen(optionToList)
assert(listToList(list) === list.take(1))
}
}

test("id is identity") {
forAll { (list: List[Int]) =>
assert(FunctionK.id[List].apply(list) === list)
}
}

test("or") {
val combinedInterpreter = Test1FK.or(Test2FK)
forAll { (a: Int, b: Int) =>
assert(combinedInterpreter(EitherK.left(Test1(a))) === a)
assert(combinedInterpreter(EitherK.right(Test2(b))) === b)
}
}

test("and") {
val combinedInterpreter = listToOption.and(listToVector)
forAll { (list: List[Int]) =>
val prod = combinedInterpreter(list)
assert(prod.first === list.headOption)
assert(prod.second === list.toVector)
}
}

test("lift simple unary") {
def optionToList[A](option: Option[A]): List[A] = option.toList
val fOptionToList = FunctionK.lift(optionToList _)
forAll { (a: Option[Int]) =>
assert(fOptionToList(a) === optionToList(a))
}

val fO2I: FunctionK[Option, Iterable] = FunctionK.lift(Option.option2Iterable _)
forAll { (a: Option[String]) =>
assert(fO2I(a).toList === Option.option2Iterable(a).toList)
}

val fNelFromListUnsafe = FunctionK.lift(NonEmptyList.fromListUnsafe _)
forAll { (a: NonEmptyList[Int]) =>
assert(fNelFromListUnsafe(a.toList) === NonEmptyList.fromListUnsafe(a.toList))
}
}

test("hygiene") {
trait FunctionK
def optionToList[A](option: Option[A]): List[A] = option.toList
val fOptionToList = cats.arrow.FunctionK.lift(optionToList _)
forAll { (a: Option[Int]) =>
assert(fOptionToList(a) === optionToList(a))
}
}

test("lift compound unary") {
val fNelFromList = FunctionK.lift[List, OptionOfNel](NonEmptyList.fromList)
forAll { (a: List[String]) =>
assert(fNelFromList(a) === NonEmptyList.fromList(a))
}
}

test("lift eta-expanded function") {
val fSomeNel = FunctionK.lift[NonEmptyList, OptionOfNel](Applicative[Option].pure)
forAll { (a: NonEmptyList[Int]) =>
assert(fSomeNel(a) === Some(a))
}
}

test("lift a function directly") {
def headOption[A](list: List[A]): Option[A] = list.headOption
val fHeadOption = FunctionK.liftFunction[List, Option](headOption)
forAll { (a: List[Int]) =>
assert(fHeadOption(a) === a.headOption)
}
}

{ // lifting concrete types should fail to compile
def sample[A](option: Option[A]): List[A] = option.toList
assert(compileErrors("FunctionK.lift(sample[String])").nonEmpty)
assert(compileErrors("FunctionK.lift(sample[Nothing])").nonEmpty)
}
}
67 changes: 0 additions & 67 deletions tests/src/test/scala/cats/tests/FunctionKSuite.scala

This file was deleted.

0 comments on commit 8e8be51

Please sign in to comment.