diff --git a/free/src/main/scala/cats/free/Free.scala b/free/src/main/scala/cats/free/Free.scala index 8d029fae51..b07127b180 100644 --- a/free/src/main/scala/cats/free/Free.scala +++ b/free/src/main/scala/cats/free/Free.scala @@ -54,6 +54,20 @@ sealed abstract class Free[S[_], A] extends Product with Serializable { } } + /** + * A combination of step and fold. + */ + private[free] final def foldStep[B]( + onPure: A => B, + onSuspend: S[A] => B, + onFlatMapped: ((S[X], X => Free[S, A]) forSome { type X }) => B + ): B = this.step match { + case Pure(a) => onPure(a) + case Suspend(a) => onSuspend(a) + case FlatMapped(Suspend(fa), f) => onFlatMapped((fa, f)) + case _ => sys.error("FlatMapped should be right associative after step") + } + /** * Run to completion, using a function that extracts the resumption * from its suspension functor. @@ -161,7 +175,7 @@ sealed abstract class Free[S[_], A] extends Product with Serializable { "Free(...)" } -object Free { +object Free extends FreeInstances { /** * Return from the computation with the given value. @@ -250,3 +264,56 @@ object Free { def flatMap[A, B](a: Free[S, A])(f: A => Free[S, B]): Free[S, B] = a.flatMap(f) } } + +private trait FreeFoldable[F[_]] extends Foldable[Free[F, ?]] { + + implicit def F: Foldable[F] + + override final def foldLeft[A, B](fa: Free[F, A], b: B)(f: (B, A) => B): B = + fa.foldStep( + a => f(b, a), + fa => F.foldLeft(fa, b)(f), + { case (fx, g) => F.foldLeft(fx, b)((bb, x) => foldLeft(g(x), bb)(f)) } + ) + + override final def foldRight[A, B](fa: Free[F, A], lb: Eval[B])(f: (A, Eval[B]) => Eval[B]): Eval[B] = + fa.foldStep( + a => f(a, lb), + fa => F.foldRight(fa, lb)(f), + { case (fx, g) => F.foldRight(fx, lb)( (a, lbb) => foldRight(g(a), lbb)(f)) } + ) +} + +private trait FreeTraverse[F[_]] extends Traverse[Free[F, ?]] with FreeFoldable[F] { + implicit def TraversableF: Traverse[F] + + def F: Foldable[F] = TraversableF + + override final def traverse[G[_], A, B](fa: Free[F, A])(f: A => G[B])(implicit G: Applicative[G]): G[Free[F, B]] = + fa.resume match { + case Right(a) => G.map(f(a))(Free.pure(_)) + case Left(ffreeA) => G.map(TraversableF.traverse(ffreeA)(traverse(_)(f)))(Free.roll(_)) + } + + // Override Traverse's map to use Free's map for better performance + override final def map[A, B](fa: Free[F, A])(f: A => B): Free[F, B] = fa.map(f) +} + +sealed private[free] abstract class FreeInstances { + + implicit def catsFreeFoldableForFree[F[_]]( + implicit + foldableF: Foldable[F] + ): Foldable[Free[F, ?]] = + new FreeFoldable[F] { + val F = foldableF + } + + implicit def catsFreeTraverseForFree[F[_]]( + implicit + traversableF: Traverse[F] + ): Traverse[Free[F, ?]] = + new FreeTraverse[F] { + val TraversableF = traversableF + } +} diff --git a/free/src/test/scala/cats/free/FreeTests.scala b/free/src/test/scala/cats/free/FreeTests.scala index 4097b36e10..0eda77a71e 100644 --- a/free/src/test/scala/cats/free/FreeTests.scala +++ b/free/src/test/scala/cats/free/FreeTests.scala @@ -3,7 +3,7 @@ package free import cats.arrow.FunctionK import cats.data.EitherK -import cats.laws.discipline.{CartesianTests, MonadTests, SerializableTests} +import cats.laws.discipline.{CartesianTests, FoldableTests, MonadTests, SerializableTests, TraverseTests} import cats.laws.discipline.arbitrary.catsLawsArbitraryForFn0 import cats.tests.CatsSuite @@ -18,6 +18,19 @@ class FreeTests extends CatsSuite { checkAll("Free[Option, ?]", MonadTests[Free[Option, ?]].monad[Int, Int, Int]) checkAll("Monad[Free[Option, ?]]", SerializableTests.serializable(Monad[Free[Option, ?]])) + locally { + implicit val instance = Free.catsFreeFoldableForFree[Option] + + checkAll("Free[Option, ?]", FoldableTests[Free[Option,?]].foldable[Int,Int]) + checkAll("Foldable[Free[Option,?]]", SerializableTests.serializable(Foldable[Free[Option,?]])) + } + + locally { + implicit val instance = Free.catsFreeTraverseForFree[Option] + checkAll("Free[Option,?]", TraverseTests[Free[Option,?]].traverse[Int, Int, Int, Int, Option, Option]) + checkAll("Traverse[Free[Option,?]]", SerializableTests.serializable(Traverse[Free[Option,?]])) + } + test("toString is stack-safe") { val r = Free.pure[List, Int](333) val rr = (1 to 1000000).foldLeft(r)((r, _) => r.map(_ + 1))