Skip to content

Commit

Permalink
Merge pull request #1076 from TomasMikula/state-monadrec
Browse files Browse the repository at this point in the history
MonadRec instances for Eval and StateT.
  • Loading branch information
non committed May 31, 2016
2 parents d3c64d1 + de5d911 commit 6c6b2a8
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 22 deletions.
10 changes: 8 additions & 2 deletions core/src/main/scala/cats/Eval.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package cats

import scala.annotation.tailrec
import cats.data.Xor
import cats.syntax.all._

/**
Expand Down Expand Up @@ -294,14 +295,19 @@ object Eval extends EvalInstances {

private[cats] trait EvalInstances extends EvalInstances0 {

implicit val evalBimonad: Bimonad[Eval] =
new Bimonad[Eval] {
implicit val evalBimonad: Bimonad[Eval] with MonadRec[Eval] =
new Bimonad[Eval] with MonadRec[Eval] {
override def map[A, B](fa: Eval[A])(f: A => B): Eval[B] = fa.map(f)
def pure[A](a: A): Eval[A] = Now(a)
override def pureEval[A](la: Eval[A]): Eval[A] = la
def flatMap[A, B](fa: Eval[A])(f: A => Eval[B]): Eval[B] = fa.flatMap(f)
def extract[A](la: Eval[A]): A = la.value
def coflatMap[A, B](fa: Eval[A])(f: Eval[A] => B): Eval[B] = Later(f(fa))
def tailRecM[A, B](a: A)(f: A => Eval[A Xor B]): Eval[B] =
f(a).flatMap(_ match {
case Xor.Left(a1) => tailRecM(a1)(f) // recursion OK here, since flatMap is lazy
case Xor.Right(b) => Eval.now(b)
})
}

implicit def evalOrder[A: Order]: Order[Eval[A]] =
Expand Down
52 changes: 36 additions & 16 deletions core/src/main/scala/cats/data/StateT.scala
Original file line number Diff line number Diff line change
Expand Up @@ -135,22 +135,9 @@ object StateT extends StateTInstances {
StateT(s => F.pure((s, a)))
}

private[data] sealed abstract class StateTInstances {
implicit def catsDataMonadStateForStateT[F[_], S](implicit F: Monad[F]): MonadState[StateT[F, S, ?], S] =
new MonadState[StateT[F, S, ?], S] {
def pure[A](a: A): StateT[F, S, A] =
StateT.pure(a)

def flatMap[A, B](fa: StateT[F, S, A])(f: A => StateT[F, S, B]): StateT[F, S, B] =
fa.flatMap(f)

val get: StateT[F, S, S] = StateT(a => F.pure((a, a)))

def set(s: S): StateT[F, S, Unit] = StateT(_ => F.pure((s, ())))

override def map[A, B](fa: StateT[F, S, A])(f: A => B): StateT[F, S, B] =
fa.map(f)
}
private[data] sealed abstract class StateTInstances extends StateTInstances1 {
implicit def catsDataMonadStateForStateT[F[_], S](implicit F0: Monad[F]): MonadState[StateT[F, S, ?], S] =
new StateTMonadState[F, S] { implicit def F = F0 }

implicit def catsDataLiftForStateT[S]: TransLift.Aux[StateT[?[_], S, ?], Applicative] =
new TransLift[StateT[?[_], S, ?]] {
Expand All @@ -161,6 +148,11 @@ private[data] sealed abstract class StateTInstances {

}

private[data] sealed abstract class StateTInstances1 {
implicit def catsDataMonadRecForStateT[F[_], S](implicit F0: MonadRec[F]): MonadRec[StateT[F, S, ?]] =
new StateTMonadRec[F, S] { implicit def F = F0 }
}

// To workaround SI-7139 `object State` needs to be defined inside the package object
// together with the type alias.
private[data] abstract class StateFunctions {
Expand Down Expand Up @@ -193,3 +185,31 @@ private[data] abstract class StateFunctions {
*/
def set[S](s: S): State[S, Unit] = State(_ => (s, ()))
}

private[data] sealed trait StateTMonad[F[_], S] extends Monad[StateT[F, S, ?]] {
implicit def F: Monad[F]

def pure[A](a: A): StateT[F, S, A] =
StateT.pure(a)

def flatMap[A, B](fa: StateT[F, S, A])(f: A => StateT[F, S, B]): StateT[F, S, B] =
fa.flatMap(f)

override def map[A, B](fa: StateT[F, S, A])(f: A => B): StateT[F, S, B] =
fa.map(f)
}

private[data] sealed trait StateTMonadState[F[_], S] extends MonadState[StateT[F, S, ?], S] with StateTMonad[F, S] {
val get: StateT[F, S, S] = StateT(s => F.pure((s, s)))

def set(s: S): StateT[F, S, Unit] = StateT(_ => F.pure((s, ())))
}

private[data] sealed trait StateTMonadRec[F[_], S] extends MonadRec[StateT[F, S, ?]] with StateTMonad[F, S] {
override implicit def F: MonadRec[F]

def tailRecM[A, B](a: A)(f: A => StateT[F, S, A Xor B]): StateT[F, S, B] =
StateT[F, S, B](s => F.tailRecM[(S, A), (S, B)]((s, a)) {
case (s, a) => F.map(f(a).run(s)) { case (s, ab) => ab.bimap((s, _), (s, _)) }
})
}
4 changes: 3 additions & 1 deletion tests/src/test/scala/cats/tests/EvalTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package tests

import scala.math.min
import cats.laws.ComonadLaws
import cats.laws.discipline.{CartesianTests, BimonadTests, SerializableTests}
import cats.laws.discipline.{BimonadTests, CartesianTests, MonadRecTests, SerializableTests}
import cats.laws.discipline.arbitrary._
import cats.kernel.laws.{GroupLaws, OrderLaws}

Expand Down Expand Up @@ -93,8 +93,10 @@ class EvalTests extends CatsSuite {
{
implicit val iso = CartesianTests.Isomorphisms.invariant[Eval]
checkAll("Eval[Int]", BimonadTests[Eval].bimonad[Int, Int, Int])
checkAll("Eval[Int]", MonadRecTests[Eval].monadRec[Int, Int, Int])
}
checkAll("Bimonad[Eval]", SerializableTests.serializable(Bimonad[Eval]))
checkAll("MonadRec[Eval]", SerializableTests.serializable(MonadRec[Eval]))

checkAll("Eval[Int]", GroupLaws[Eval[Int]].group)

Expand Down
11 changes: 10 additions & 1 deletion tests/src/test/scala/cats/tests/MonadRecInstancesTests.scala
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package cats
package tests

import cats.data.{OptionT, Xor, XorT}
import cats.data.{OptionT, StateT, Xor, XorT}

class MonadRecInstancesTests extends CatsSuite {
def tailRecMStackSafety[M[_]](implicit M: MonadRec[M], Eq: Eq[M[Int]]): Unit = {
Expand Down Expand Up @@ -38,4 +38,13 @@ class MonadRecInstancesTests extends CatsSuite {
tailRecMStackSafety[List]
}

test("tailRecM stack-safety for Eval") {
tailRecMStackSafety[Eval]
}

test("tailRecM stack-safety for StateT") {
import StateTTests._ // import implicit Eq[StateT[...]]
tailRecMStackSafety[StateT[Option, Int, ?]]
}

}
12 changes: 10 additions & 2 deletions tests/src/test/scala/cats/tests/StateTTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package cats
package tests

import cats.kernel.std.tuple._
import cats.laws.discipline.{CartesianTests, MonadStateTests, SerializableTests}
import cats.laws.discipline.{CartesianTests, MonadRecTests, MonadStateTests, SerializableTests}
import cats.data.{State, StateT}
import cats.laws.discipline.eq._
import cats.laws.discipline.arbitrary._
Expand Down Expand Up @@ -116,14 +116,22 @@ class StateTTests extends CatsSuite {

{
implicit val iso = CartesianTests.Isomorphisms.invariant[StateT[Option, Int, ?]]

checkAll("StateT[Option, Int, Int]", MonadStateTests[StateT[Option, Int, ?], Int].monadState[Int, Int, Int])
checkAll("MonadState[StateT[Option, ?, ?], Int]", SerializableTests.serializable(MonadState[StateT[Option, Int, ?], Int]))
checkAll("MonadState[StateT[Option, Int, ?], Int]", SerializableTests.serializable(MonadState[StateT[Option, Int, ?], Int]))

checkAll("StateT[Option, Int, Int]", MonadRecTests[StateT[Option, Int, ?]].monadRec[Int, Int, Int])
checkAll("MonadRec[StateT[Option, Int, ?]]", SerializableTests.serializable(MonadRec[StateT[Option, Int, ?]]))
}

{
implicit val iso = CartesianTests.Isomorphisms.invariant[State[Long, ?]]

checkAll("State[Long, ?]", MonadStateTests[State[Long, ?], Long].monadState[Int, Int, Int])
checkAll("MonadState[State[Long, ?], Long]", SerializableTests.serializable(MonadState[State[Long, ?], Long]))

checkAll("State[Long, ?]", MonadRecTests[State[Long, ?]].monadRec[Int, Int, Int])
checkAll("MonadRec[State[Long, ?]]", SerializableTests.serializable(MonadRec[State[Long, ?]]))
}
}

Expand Down

0 comments on commit 6c6b2a8

Please sign in to comment.