From 2eb3099f45c1a4574b9d35e1dbf58e4dd5e347f3 Mon Sep 17 00:00:00 2001 From: aaron levin Date: Tue, 22 Aug 2017 00:00:20 +0200 Subject: [PATCH] Add Foldable and Traversable instances for Free --- free/src/main/scala/cats/free/Free.scala | 70 ++++++++++++++++++- free/src/test/scala/cats/free/FreeTests.scala | 35 +++++++++- 2 files changed, 103 insertions(+), 2 deletions(-) diff --git a/free/src/main/scala/cats/free/Free.scala b/free/src/main/scala/cats/free/Free.scala index 8d029fae517..bf895907f49 100644 --- a/free/src/main/scala/cats/free/Free.scala +++ b/free/src/main/scala/cats/free/Free.scala @@ -54,6 +54,20 @@ sealed abstract class Free[S[_], A] extends Product with Serializable { } } + /** + * A combination of step and fold. + */ + private[free] final def foldStep[B]( + onPure: A => B, + onSuspend: S[A] => B, + onFlatMapped: ((S[X], X => Free[S, A]) forSome { type X }) => B + ): B = this.step match { + case Pure(a) => onPure(a) + case Suspend(a) => onSuspend(a) + case FlatMapped(Suspend(fa), f) => onFlatMapped((fa, f)) + case _ => sys.error("FlatMapped should be right associative after step") + } + /** * Run to completion, using a function that extracts the resumption * from its suspension functor. @@ -161,7 +175,7 @@ sealed abstract class Free[S[_], A] extends Product with Serializable { "Free(...)" } -object Free { +object Free extends FreeInstances { /** * Return from the computation with the given value. @@ -250,3 +264,57 @@ object Free { def flatMap[A, B](a: Free[S, A])(f: A => Free[S, B]): Free[S, B] = a.flatMap(f) } } + +private trait FreeFoldable[F[_]] extends Foldable[Free[F, ?]] { + + implicit def F: Foldable[F] + + override final def foldLeft[A, B](fa: Free[F, A], b: B)(f: (B, A) => B): B = + fa.foldStep( + a => f(b, a), + fa => F.foldLeft(fa, b)(f), + { case (fx, g) => F.foldLeft(fx, b)((bb, x) => foldLeft(g(x), bb)(f)) } + ) + + override final def foldRight[A, B](fa: Free[F, A], lb: Eval[B])(f: (A, Eval[B]) => Eval[B]): Eval[B] = + fa.foldStep( + a => f(a, lb), + fa => F.foldRight(fa, lb)(f), + { case (fx, g) => F.foldRight(fx, lb)( (a, lbb) => foldRight(g(a), lbb)(f)) } + ) +} + +private trait FreeTraverse[F[_]] extends Traverse[Free[F, ?]] with FreeFoldable[F] { + implicit def TraversableF: Traverse[F] + + def F: Foldable[F] = TraversableF + + override final def traverse[G[_], A, B](fa: Free[F, A])(f: A => G[B])(implicit G: Applicative[G]): G[Free[F, B]] = + fa.resume match { + case Right(a) => G.map(f(a))(Free.pure(_)) + case Left(ffreeA) => G.map(TraversableF.traverse(ffreeA)(traverse(_)(f)))(Free.roll(_)) + } + + // Override Traverse's map to use Free's map for better performance + override final def map[A, B](fa: Free[F, A])(f: A => B): Free[F, B] = fa.map(f) +} + +sealed private[free] abstract class FreeInstances { + + implicit def catsFreeFoldableForFree[F[_]]( + implicit + foldableF: Foldable[F] + ): Foldable[Free[F, ?]] = + new FreeFoldable[F] { + val F = foldableF + } + + implicit def catsFreeTraverseForFree[F[_]]( + implicit + traversableF: Traverse[F] + ): Traverse[Free[F, ?]] = + new FreeTraverse[F] { + val TraversableF = traversableF + val FunctorF = traversableF + } +} diff --git a/free/src/test/scala/cats/free/FreeTests.scala b/free/src/test/scala/cats/free/FreeTests.scala index 4097b36e105..784a4cfa471 100644 --- a/free/src/test/scala/cats/free/FreeTests.scala +++ b/free/src/test/scala/cats/free/FreeTests.scala @@ -3,7 +3,7 @@ package free import cats.arrow.FunctionK import cats.data.EitherK -import cats.laws.discipline.{CartesianTests, MonadTests, SerializableTests} +import cats.laws.discipline.{CartesianTests, FoldableTests, MonadTests, SerializableTests, TraverseTests} import cats.laws.discipline.arbitrary.catsLawsArbitraryForFn0 import cats.tests.CatsSuite @@ -18,6 +18,19 @@ class FreeTests extends CatsSuite { checkAll("Free[Option, ?]", MonadTests[Free[Option, ?]].monad[Int, Int, Int]) checkAll("Monad[Free[Option, ?]]", SerializableTests.serializable(Monad[Free[Option, ?]])) + locally { + implicit val instance = Free.catsFreeFoldableForFree[Option] + + checkAll("Free[Option, ?]", FoldableTests[Free[Option,?]].foldable[Int,Int]) + checkAll("Foldable[Free[Option,?]]", SerializableTests.serializable(Foldable[Free[Option,?]])) + } + + locally { + implicit val instance = Free.catsFreeTraverseForFree[Option] + checkAll("Free[Option,?]", TraverseTests[Free[Option,?]].traverse[Int, Int, Int, Int, Option, Option]) + checkAll("Traverse[Free[Option,?]]", SerializableTests.serializable(Traverse[Free[Option,?]])) + } + test("toString is stack-safe") { val r = Free.pure[List, Int](333) val rr = (1 to 1000000).foldLeft(r)((r, _) => r.map(_ + 1)) @@ -82,6 +95,26 @@ class FreeTests extends CatsSuite { assert(10000 == a(0).foldMap(runner)) } + test("foldRight is stack safe") { + val instance = Free.catsFreeFoldableForFree[Option] + val n = 50000 + val freeOption: Int => Free[Option, Int] = x => Free.pure(x) + val free = (1 to n).foldLeft(freeOption(0))((r, _) => r.flatMap(n => freeOption(n + 1))) + val result = instance.foldRight(free, Eval.now(0))((a, lb) => lb.map(_ + a)).value + + assert(n == result) + } + + test("foldLeft is stack safe") { + val instance = Free.catsFreeFoldableForFree[Option] + val n = 50000 + val freeOption: Int => Free[Option, Int] = x => Free.pure(x) + val free = (1 to n).foldLeft(freeOption(0))((r, _) => r.flatMap(n => freeOption(n + 1))) + val result = instance.foldLeft(free, 0)(_ + _) + + assert(n == result) + } + test(".runTailRec") { val r = Free.pure[List, Int](12358) def recurse(r: Free[List, Int], n: Int): Free[List, Int] =