From de5d911e02f2fb70d07da05ce1623b4d8b9b8d3e Mon Sep 17 00:00:00 2001 From: Tomas Mikula Date: Tue, 31 May 2016 01:00:19 -0400 Subject: [PATCH] MonadRec instances for Eval and StateT. --- core/src/main/scala/cats/Eval.scala | 10 +++- core/src/main/scala/cats/data/StateT.scala | 52 +++++++++++++------ .../src/test/scala/cats/tests/EvalTests.scala | 4 +- .../cats/tests/MonadRecInstancesTests.scala | 11 +++- .../test/scala/cats/tests/StateTTests.scala | 12 ++++- 5 files changed, 67 insertions(+), 22 deletions(-) diff --git a/core/src/main/scala/cats/Eval.scala b/core/src/main/scala/cats/Eval.scala index b0970cc808..a56e6215bf 100644 --- a/core/src/main/scala/cats/Eval.scala +++ b/core/src/main/scala/cats/Eval.scala @@ -1,6 +1,7 @@ package cats import scala.annotation.tailrec +import cats.data.Xor import cats.syntax.all._ /** @@ -294,14 +295,19 @@ object Eval extends EvalInstances { private[cats] trait EvalInstances extends EvalInstances0 { - implicit val evalBimonad: Bimonad[Eval] = - new Bimonad[Eval] { + implicit val evalBimonad: Bimonad[Eval] with MonadRec[Eval] = + new Bimonad[Eval] with MonadRec[Eval] { override def map[A, B](fa: Eval[A])(f: A => B): Eval[B] = fa.map(f) def pure[A](a: A): Eval[A] = Now(a) override def pureEval[A](la: Eval[A]): Eval[A] = la def flatMap[A, B](fa: Eval[A])(f: A => Eval[B]): Eval[B] = fa.flatMap(f) def extract[A](la: Eval[A]): A = la.value def coflatMap[A, B](fa: Eval[A])(f: Eval[A] => B): Eval[B] = Later(f(fa)) + def tailRecM[A, B](a: A)(f: A => Eval[A Xor B]): Eval[B] = + f(a).flatMap(_ match { + case Xor.Left(a1) => tailRecM(a1)(f) // recursion OK here, since flatMap is lazy + case Xor.Right(b) => Eval.now(b) + }) } implicit def evalOrder[A: Order]: Order[Eval[A]] = diff --git a/core/src/main/scala/cats/data/StateT.scala b/core/src/main/scala/cats/data/StateT.scala index e8f1406ac0..169e7832d7 100644 --- a/core/src/main/scala/cats/data/StateT.scala +++ b/core/src/main/scala/cats/data/StateT.scala @@ -135,22 +135,9 @@ object StateT extends StateTInstances { StateT(s => F.pure((s, a))) } -private[data] sealed abstract class StateTInstances { - implicit def catsDataMonadStateForStateT[F[_], S](implicit F: Monad[F]): MonadState[StateT[F, S, ?], S] = - new MonadState[StateT[F, S, ?], S] { - def pure[A](a: A): StateT[F, S, A] = - StateT.pure(a) - - def flatMap[A, B](fa: StateT[F, S, A])(f: A => StateT[F, S, B]): StateT[F, S, B] = - fa.flatMap(f) - - val get: StateT[F, S, S] = StateT(a => F.pure((a, a))) - - def set(s: S): StateT[F, S, Unit] = StateT(_ => F.pure((s, ()))) - - override def map[A, B](fa: StateT[F, S, A])(f: A => B): StateT[F, S, B] = - fa.map(f) - } +private[data] sealed abstract class StateTInstances extends StateTInstances1 { + implicit def catsDataMonadStateForStateT[F[_], S](implicit F0: Monad[F]): MonadState[StateT[F, S, ?], S] = + new StateTMonadState[F, S] { implicit def F = F0 } implicit def catsDataLiftForStateT[S]: TransLift.Aux[StateT[?[_], S, ?], Applicative] = new TransLift[StateT[?[_], S, ?]] { @@ -161,6 +148,11 @@ private[data] sealed abstract class StateTInstances { } +private[data] sealed abstract class StateTInstances1 { + implicit def catsDataMonadRecForStateT[F[_], S](implicit F0: MonadRec[F]): MonadRec[StateT[F, S, ?]] = + new StateTMonadRec[F, S] { implicit def F = F0 } +} + // To workaround SI-7139 `object State` needs to be defined inside the package object // together with the type alias. private[data] abstract class StateFunctions { @@ -193,3 +185,31 @@ private[data] abstract class StateFunctions { */ def set[S](s: S): State[S, Unit] = State(_ => (s, ())) } + +private[data] sealed trait StateTMonad[F[_], S] extends Monad[StateT[F, S, ?]] { + implicit def F: Monad[F] + + def pure[A](a: A): StateT[F, S, A] = + StateT.pure(a) + + def flatMap[A, B](fa: StateT[F, S, A])(f: A => StateT[F, S, B]): StateT[F, S, B] = + fa.flatMap(f) + + override def map[A, B](fa: StateT[F, S, A])(f: A => B): StateT[F, S, B] = + fa.map(f) +} + +private[data] sealed trait StateTMonadState[F[_], S] extends MonadState[StateT[F, S, ?], S] with StateTMonad[F, S] { + val get: StateT[F, S, S] = StateT(s => F.pure((s, s))) + + def set(s: S): StateT[F, S, Unit] = StateT(_ => F.pure((s, ()))) +} + +private[data] sealed trait StateTMonadRec[F[_], S] extends MonadRec[StateT[F, S, ?]] with StateTMonad[F, S] { + override implicit def F: MonadRec[F] + + def tailRecM[A, B](a: A)(f: A => StateT[F, S, A Xor B]): StateT[F, S, B] = + StateT[F, S, B](s => F.tailRecM[(S, A), (S, B)]((s, a)) { + case (s, a) => F.map(f(a).run(s)) { case (s, ab) => ab.bimap((s, _), (s, _)) } + }) +} diff --git a/tests/src/test/scala/cats/tests/EvalTests.scala b/tests/src/test/scala/cats/tests/EvalTests.scala index 76fb4f48d0..6e6ddaa44d 100644 --- a/tests/src/test/scala/cats/tests/EvalTests.scala +++ b/tests/src/test/scala/cats/tests/EvalTests.scala @@ -3,7 +3,7 @@ package tests import scala.math.min import cats.laws.ComonadLaws -import cats.laws.discipline.{CartesianTests, BimonadTests, SerializableTests} +import cats.laws.discipline.{BimonadTests, CartesianTests, MonadRecTests, SerializableTests} import cats.laws.discipline.arbitrary._ import cats.kernel.laws.{GroupLaws, OrderLaws} @@ -93,8 +93,10 @@ class EvalTests extends CatsSuite { { implicit val iso = CartesianTests.Isomorphisms.invariant[Eval] checkAll("Eval[Int]", BimonadTests[Eval].bimonad[Int, Int, Int]) + checkAll("Eval[Int]", MonadRecTests[Eval].monadRec[Int, Int, Int]) } checkAll("Bimonad[Eval]", SerializableTests.serializable(Bimonad[Eval])) + checkAll("MonadRec[Eval]", SerializableTests.serializable(MonadRec[Eval])) checkAll("Eval[Int]", GroupLaws[Eval[Int]].group) diff --git a/tests/src/test/scala/cats/tests/MonadRecInstancesTests.scala b/tests/src/test/scala/cats/tests/MonadRecInstancesTests.scala index 649cd60c01..89ad883e0d 100644 --- a/tests/src/test/scala/cats/tests/MonadRecInstancesTests.scala +++ b/tests/src/test/scala/cats/tests/MonadRecInstancesTests.scala @@ -1,7 +1,7 @@ package cats package tests -import cats.data.{OptionT, Xor, XorT} +import cats.data.{OptionT, StateT, Xor, XorT} class MonadRecInstancesTests extends CatsSuite { def tailRecMStackSafety[M[_]](implicit M: MonadRec[M], Eq: Eq[M[Int]]): Unit = { @@ -38,4 +38,13 @@ class MonadRecInstancesTests extends CatsSuite { tailRecMStackSafety[List] } + test("tailRecM stack-safety for Eval") { + tailRecMStackSafety[Eval] + } + + test("tailRecM stack-safety for StateT") { + import StateTTests._ // import implicit Eq[StateT[...]] + tailRecMStackSafety[StateT[Option, Int, ?]] + } + } diff --git a/tests/src/test/scala/cats/tests/StateTTests.scala b/tests/src/test/scala/cats/tests/StateTTests.scala index 677af2ea02..1f7b3aee75 100644 --- a/tests/src/test/scala/cats/tests/StateTTests.scala +++ b/tests/src/test/scala/cats/tests/StateTTests.scala @@ -2,7 +2,7 @@ package cats package tests import cats.kernel.std.tuple._ -import cats.laws.discipline.{CartesianTests, MonadStateTests, SerializableTests} +import cats.laws.discipline.{CartesianTests, MonadRecTests, MonadStateTests, SerializableTests} import cats.data.{State, StateT} import cats.laws.discipline.eq._ import cats.laws.discipline.arbitrary._ @@ -116,14 +116,22 @@ class StateTTests extends CatsSuite { { implicit val iso = CartesianTests.Isomorphisms.invariant[StateT[Option, Int, ?]] + checkAll("StateT[Option, Int, Int]", MonadStateTests[StateT[Option, Int, ?], Int].monadState[Int, Int, Int]) - checkAll("MonadState[StateT[Option, ?, ?], Int]", SerializableTests.serializable(MonadState[StateT[Option, Int, ?], Int])) + checkAll("MonadState[StateT[Option, Int, ?], Int]", SerializableTests.serializable(MonadState[StateT[Option, Int, ?], Int])) + + checkAll("StateT[Option, Int, Int]", MonadRecTests[StateT[Option, Int, ?]].monadRec[Int, Int, Int]) + checkAll("MonadRec[StateT[Option, Int, ?]]", SerializableTests.serializable(MonadRec[StateT[Option, Int, ?]])) } { implicit val iso = CartesianTests.Isomorphisms.invariant[State[Long, ?]] + checkAll("State[Long, ?]", MonadStateTests[State[Long, ?], Long].monadState[Int, Int, Int]) checkAll("MonadState[State[Long, ?], Long]", SerializableTests.serializable(MonadState[State[Long, ?], Long])) + + checkAll("State[Long, ?]", MonadRecTests[State[Long, ?]].monadRec[Int, Int, Int]) + checkAll("MonadRec[State[Long, ?]]", SerializableTests.serializable(MonadRec[State[Long, ?]])) } }