Skip to content

Commit

Permalink
Merge pull request #1041 from TomasMikula/MonadRec-instances
Browse files Browse the repository at this point in the history
Some MonadRec instances.
  • Loading branch information
non committed May 31, 2016
2 parents 2ab7b93 + 5a38209 commit d3c64d1
Show file tree
Hide file tree
Showing 22 changed files with 333 additions and 36 deletions.
24 changes: 24 additions & 0 deletions core/src/main/scala/cats/FlatMapRec.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package cats

import simulacrum.typeclass

import cats.data.Xor

/**
* Version of [[cats.FlatMap]] capable of stack-safe recursive `flatMap`s.
*
* Based on Phil Freeman's
* [[http://functorial.com/stack-safety-for-free/index.pdf Stack Safety for Free]].
*/
@typeclass trait FlatMapRec[F[_]] extends FlatMap[F] {

/**
* Keeps calling `f` until a `[[cats.data.Xor.Right Right]][B]` is returned.
*
* Implementations of this method must use constant stack space.
*
* `f` must use constant stack space. (It is OK to use a constant number of
* `map`s and `flatMap`s inside `f`.)
*/
def tailRecM[A, B](a: A)(f: A => F[A Xor B]): F[B]
}
5 changes: 5 additions & 0 deletions core/src/main/scala/cats/MonadRec.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package cats

import simulacrum.typeclass

@typeclass trait MonadRec[F[_]] extends Monad[F] with FlatMapRec[F]
42 changes: 31 additions & 11 deletions core/src/main/scala/cats/data/OptionT.scala
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ object OptionT extends OptionTInstances {
def liftF[F[_], A](fa: F[A])(implicit F: Functor[F]): OptionT[F, A] = OptionT(F.map(fa)(Some(_)))
}

private[data] sealed trait OptionTInstances1 {
private[data] sealed trait OptionTInstances2 {
implicit def catsDataFunctorForOptionT[F[_]:Functor]: Functor[OptionT[F, ?]] =
new Functor[OptionT[F, ?]] {
override def map[A, B](fa: OptionT[F, A])(f: A => B): OptionT[F, B] =
Expand All @@ -148,22 +148,42 @@ private[data] sealed trait OptionTInstances1 {
}
}

private[data] sealed trait OptionTInstances extends OptionTInstances1 {

implicit def catsDataMonadForOptionT[F[_]](implicit F: Monad[F]): Monad[OptionT[F, ?]] =
new Monad[OptionT[F, ?]] {
def pure[A](a: A): OptionT[F, A] = OptionT.pure(a)
private[data] sealed trait OptionTInstances1 extends OptionTInstances2 {

def flatMap[A, B](fa: OptionT[F, A])(f: A => OptionT[F, B]): OptionT[F, B] =
fa.flatMap(f)
implicit def catsDataMonadForOptionT[F[_]](implicit F0: Monad[F]): Monad[OptionT[F, ?]] =
new OptionTMonad[F] { implicit val F = F0 }
}

override def map[A, B](fa: OptionT[F, A])(f: A => B): OptionT[F, B] =
fa.map(f)
}
private[data] sealed trait OptionTInstances extends OptionTInstances1 {
implicit def catsDataMonadRecForOptionT[F[_]](implicit F0: MonadRec[F]): MonadRec[OptionT[F, ?]] =
new OptionTMonadRec[F] { implicit val F = F0 }

implicit def catsDataEqForOptionT[F[_], A](implicit FA: Eq[F[Option[A]]]): Eq[OptionT[F, A]] =
FA.on(_.value)

implicit def catsDataShowForOptionT[F[_], A](implicit F: Show[F[Option[A]]]): Show[OptionT[F, A]] =
functor.Contravariant[Show].contramap(F)(_.value)
}

private[data] trait OptionTMonad[F[_]] extends Monad[OptionT[F, ?]] {
implicit val F: Monad[F]

def pure[A](a: A): OptionT[F, A] = OptionT.pure(a)

def flatMap[A, B](fa: OptionT[F, A])(f: A => OptionT[F, B]): OptionT[F, B] =
fa.flatMap(f)

override def map[A, B](fa: OptionT[F, A])(f: A => B): OptionT[F, B] =
fa.map(f)
}

private[data] trait OptionTMonadRec[F[_]] extends MonadRec[OptionT[F, ?]] with OptionTMonad[F] {
implicit val F: MonadRec[F]

def tailRecM[A, B](a: A)(f: A => OptionT[F, A Xor B]): OptionT[F, B] =
OptionT(F.tailRecM(a)(a0 => F.map(f(a0).value){
case None => Xor.Right(None)
case Some(Xor.Left(a1)) => Xor.Left(a1)
case Some(Xor.Right(b)) => Xor.Right(Some(b))
}))
}
11 changes: 9 additions & 2 deletions core/src/main/scala/cats/data/Xor.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package cats
package data

import scala.annotation.tailrec
import scala.reflect.ClassTag
import scala.util.{Failure, Success, Try}

Expand Down Expand Up @@ -233,13 +234,19 @@ private[data] sealed abstract class XorInstances extends XorInstances1 {
}
}

implicit def catsDataInstancesForXor[A]: Traverse[A Xor ?] with MonadError[Xor[A, ?], A] =
new Traverse[A Xor ?] with MonadError[Xor[A, ?], A] {
implicit def catsDataInstancesForXor[A]: Traverse[A Xor ?] with MonadRec[A Xor ?] with MonadError[Xor[A, ?], A] =
new Traverse[A Xor ?] with MonadRec[A Xor ?] with MonadError[Xor[A, ?], A] {
def traverse[F[_]: Applicative, B, C](fa: A Xor B)(f: B => F[C]): F[A Xor C] = fa.traverse(f)
def foldLeft[B, C](fa: A Xor B, c: C)(f: (C, B) => C): C = fa.foldLeft(c)(f)
def foldRight[B, C](fa: A Xor B, lc: Eval[C])(f: (B, Eval[C]) => Eval[C]): Eval[C] = fa.foldRight(lc)(f)
def flatMap[B, C](fa: A Xor B)(f: B => A Xor C): A Xor C = fa.flatMap(f)
def pure[B](b: B): A Xor B = Xor.right(b)
@tailrec def tailRecM[B, C](b: B)(f: B => A Xor (B Xor C)): A Xor C =
f(b) match {
case Xor.Left(a) => Xor.Left(a)
case Xor.Right(Xor.Left(b1)) => tailRecM(b1)(f)
case Xor.Right(Xor.Right(c)) => Xor.Right(c)
}
def handleErrorWith[B](fea: Xor[A, B])(f: A => Xor[A, B]): Xor[A, B] =
fea match {
case Xor.Left(e) => f(e)
Expand Down
22 changes: 20 additions & 2 deletions core/src/main/scala/cats/data/XorT.scala
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,11 @@ private[data] abstract class XorTInstances1 extends XorTInstances2 {
}

private[data] abstract class XorTInstances2 extends XorTInstances3 {
implicit def catsDataMonadRecForXorT[F[_], L](implicit F0: MonadRec[F]): MonadRec[XorT[F, L, ?]] =
new XorTMonadRec[F, L] { implicit val F = F0 }
}

private[data] abstract class XorTInstances3 extends XorTInstances4 {
implicit def catsDataMonadErrorForXorT[F[_], L](implicit F: Monad[F]): MonadError[XorT[F, L, ?], L] = {
implicit val F0 = F
new XorTMonadError[F, L] { implicit val F = F0 }
Expand All @@ -299,7 +304,7 @@ private[data] abstract class XorTInstances2 extends XorTInstances3 {
}
}

private[data] abstract class XorTInstances3 {
private[data] abstract class XorTInstances4 {
implicit def catsDataFunctorForXorT[F[_], L](implicit F: Functor[F]): Functor[XorT[F, L, ?]] = {
implicit val F0 = F
new XorTFunctor[F, L] { implicit val F = F0 }
Expand All @@ -311,10 +316,13 @@ private[data] trait XorTFunctor[F[_], L] extends Functor[XorT[F, L, ?]] {
override def map[A, B](fa: XorT[F, L, A])(f: A => B): XorT[F, L, B] = fa map f
}

private[data] trait XorTMonadError[F[_], L] extends MonadError[XorT[F, L, ?], L] with XorTFunctor[F, L] {
private[data] trait XorTMonad[F[_], L] extends Monad[XorT[F, L, ?]] with XorTFunctor[F, L] {
implicit val F: Monad[F]
def pure[A](a: A): XorT[F, L, A] = XorT.pure[F, L, A](a)
def flatMap[A, B](fa: XorT[F, L, A])(f: A => XorT[F, L, B]): XorT[F, L, B] = fa flatMap f
}

private[data] trait XorTMonadError[F[_], L] extends MonadError[XorT[F, L, ?], L] with XorTMonad[F, L] {
def handleErrorWith[A](fea: XorT[F, L, A])(f: L => XorT[F, L, A]): XorT[F, L, A] =
XorT(F.flatMap(fea.value) {
case Xor.Left(e) => f(e).value
Expand All @@ -333,6 +341,16 @@ private[data] trait XorTMonadError[F[_], L] extends MonadError[XorT[F, L, ?], L]
fla.recoverWith(pf)
}

private[data] trait XorTMonadRec[F[_], L] extends MonadRec[XorT[F, L, ?]] with XorTMonad[F, L] {
implicit val F: MonadRec[F]
def tailRecM[A, B](a: A)(f: A => XorT[F, L, A Xor B]): XorT[F, L, B] =
XorT(F.tailRecM(a)(a0 => F.map(f(a0).value){
case Xor.Left(l) => Xor.Right(Xor.Left(l))
case Xor.Right(Xor.Left(a1)) => Xor.Left(a1)
case Xor.Right(Xor.Right(b)) => Xor.Right(Xor.Right(b))
}))
}

private[data] trait XorTMonadFilter[F[_], L] extends MonadFilter[XorT[F, L, ?]] with XorTMonadError[F, L] {
implicit val F: Monad[F]
implicit val L: Monoid[L]
Expand Down
11 changes: 9 additions & 2 deletions core/src/main/scala/cats/package.scala
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import scala.annotation.tailrec
import cats.data.Xor

/**
* Symbolic aliases for various types are defined here.
*/
Expand Down Expand Up @@ -26,12 +29,16 @@ package object cats {
* encodes pure unary function application.
*/
type Id[A] = A
implicit val idInstances: Bimonad[Id] with Traverse[Id] =
new Bimonad[Id] with Traverse[Id] {
implicit val idInstances: Bimonad[Id] with MonadRec[Id] with Traverse[Id] =
new Bimonad[Id] with MonadRec[Id] with Traverse[Id] {
def pure[A](a: A): A = a
def extract[A](a: A): A = a
def flatMap[A, B](a: A)(f: A => B): B = f(a)
def coflatMap[A, B](a: A)(f: A => B): B = f(a)
@tailrec def tailRecM[A, B](a: A)(f: A => A Xor B): B = f(a) match {
case Xor.Left(a1) => tailRecM(a1)(f)
case Xor.Right(b) => b
}
override def map[A, B](fa: A)(f: A => B): B = f(fa)
override def ap[A, B](ff: A => B)(fa: A): B = ff(fa)
override def flatten[A](ffa: A): A = ffa
Expand Down
15 changes: 13 additions & 2 deletions core/src/main/scala/cats/std/either.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package cats
package std

import scala.annotation.tailrec
import cats.data.Xor

trait EitherInstances extends EitherInstances1 {
implicit val catsStdBitraverseForEither: Bitraverse[Either] =
new Bitraverse[Either] {
Expand All @@ -23,8 +26,8 @@ trait EitherInstances extends EitherInstances1 {
}
}

implicit def catsStdInstancesForEither[A]: Monad[Either[A, ?]] with Traverse[Either[A, ?]] =
new Monad[Either[A, ?]] with Traverse[Either[A, ?]] {
implicit def catsStdInstancesForEither[A]: MonadRec[Either[A, ?]] with Traverse[Either[A, ?]] =
new MonadRec[Either[A, ?]] with Traverse[Either[A, ?]] {
def pure[B](b: B): Either[A, B] = Right(b)

def flatMap[B, C](fa: Either[A, B])(f: B => Either[A, C]): Either[A, C] =
Expand All @@ -33,6 +36,14 @@ trait EitherInstances extends EitherInstances1 {
override def map[B, C](fa: Either[A, B])(f: B => C): Either[A, C] =
fa.right.map(f)

@tailrec
def tailRecM[B, C](b: B)(f: B => Either[A, B Xor C]): Either[A, C] =
f(b) match {
case Left(a) => Left(a)
case Right(Xor.Left(b1)) => tailRecM(b1)(f)
case Right(Xor.Right(c)) => Right(c)
}

override def map2Eval[B, C, Z](fb: Either[A, B], fc: Eval[Either[A, C]])(f: (B, C) => Z): Eval[Either[A, Z]] =
fb match {
// This should be safe, but we are forced to use `asInstanceOf`,
Expand Down
20 changes: 18 additions & 2 deletions core/src/main/scala/cats/std/list.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@ import cats.syntax.show._
import scala.annotation.tailrec
import scala.collection.mutable.ListBuffer

import cats.data.Xor

trait ListInstances extends cats.kernel.std.ListInstances {

implicit val catsStdInstancesForList: Traverse[List] with MonadCombine[List] with CoflatMap[List] =
new Traverse[List] with MonadCombine[List] with CoflatMap[List] {
implicit val catsStdInstancesForList: Traverse[List] with MonadCombine[List] with MonadRec[List] with CoflatMap[List] =
new Traverse[List] with MonadCombine[List] with MonadRec[List] with CoflatMap[List] {

def empty[A]: List[A] = Nil

Expand All @@ -26,6 +28,20 @@ trait ListInstances extends cats.kernel.std.ListInstances {
override def map2[A, B, Z](fa: List[A], fb: List[B])(f: (A, B) => Z): List[Z] =
fa.flatMap(a => fb.map(b => f(a, b)))

def tailRecM[A, B](a: A)(f: A => List[A Xor B]): List[B] = {
val buf = List.newBuilder[B]
@tailrec def go(lists: List[List[A Xor B]]): Unit = lists match {
case (ab :: abs) :: tail => ab match {
case Xor.Right(b) => buf += b; go(abs :: tail)
case Xor.Left(a) => go(f(a) :: abs :: tail)
}
case Nil :: tail => go(tail)
case Nil => ()
}
go(f(a) :: Nil)
buf.result
}

def coflatMap[A, B](fa: List[A])(f: List[A] => B): List[B] = {
@tailrec def loop(buf: ListBuffer[B], as: List[A]): List[B] =
as match {
Expand Down
15 changes: 13 additions & 2 deletions core/src/main/scala/cats/std/option.scala
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
package cats
package std

import scala.annotation.tailrec
import cats.data.Xor

trait OptionInstances extends cats.kernel.std.OptionInstances {

implicit val catsStdInstancesForOption: Traverse[Option] with MonadError[Option, Unit] with MonadCombine[Option] with CoflatMap[Option] with Alternative[Option] =
new Traverse[Option] with MonadError[Option, Unit] with MonadCombine[Option] with CoflatMap[Option] with Alternative[Option] {
implicit val catsStdInstancesForOption: Traverse[Option] with MonadError[Option, Unit] with MonadCombine[Option] with MonadRec[Option] with CoflatMap[Option] with Alternative[Option] =
new Traverse[Option] with MonadError[Option, Unit] with MonadCombine[Option] with MonadRec[Option] with CoflatMap[Option] with Alternative[Option] {

def empty[A]: Option[A] = None

Expand All @@ -18,6 +21,14 @@ trait OptionInstances extends cats.kernel.std.OptionInstances {
def flatMap[A, B](fa: Option[A])(f: A => Option[B]): Option[B] =
fa.flatMap(f)

@tailrec
def tailRecM[A, B](a: A)(f: A => Option[A Xor B]): Option[B] =
f(a) match {
case None => None
case Some(Xor.Left(a1)) => tailRecM(a1)(f)
case Some(Xor.Right(b)) => Some(b)
}

override def map2[A, B, Z](fa: Option[A], fb: Option[B])(f: (A, B) => Z): Option[Z] =
fa.flatMap(a => fb.map(b => f(a, b)))

Expand Down
9 changes: 7 additions & 2 deletions free/src/main/scala/cats/free/Free.scala
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,16 @@ object Free {
/**
* `Free[S, ?]` has a monad for any type constructor `S[_]`.
*/
implicit def freeMonad[S[_]]: Monad[Free[S, ?]] =
new Monad[Free[S, ?]] {
implicit def freeMonad[S[_]]: MonadRec[Free[S, ?]] =
new MonadRec[Free[S, ?]] {
def pure[A](a: A): Free[S, A] = Free.pure(a)
override def map[A, B](fa: Free[S, A])(f: A => B): Free[S, B] = fa.map(f)
def flatMap[A, B](a: Free[S, A])(f: A => Free[S, B]): Free[S, B] = a.flatMap(f)
def tailRecM[A, B](a: A)(f: A => Free[S, A Xor B]): Free[S, B] =
f(a).flatMap(_ match {
case Xor.Left(a1) => tailRecM(a1)(f) // recursion OK here, since Free is lazy
case Xor.Right(b) => pure(b)
})
}
}

Expand Down
14 changes: 11 additions & 3 deletions free/src/test/scala/cats/free/FreeTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ package free

import cats.tests.CatsSuite
import cats.arrow.NaturalTransformation
import cats.laws.discipline.{CartesianTests, MonadTests, SerializableTests}
import cats.data.Xor
import cats.laws.discipline.{CartesianTests, MonadRecTests, SerializableTests}
import cats.laws.discipline.arbitrary.function0Arbitrary

import org.scalacheck.{Arbitrary, Gen}
Expand All @@ -14,8 +15,8 @@ class FreeTests extends CatsSuite {

implicit val iso = CartesianTests.Isomorphisms.invariant[Free[Option, ?]]

checkAll("Free[Option, ?]", MonadTests[Free[Option, ?]].monad[Int, Int, Int])
checkAll("Monad[Free[Option, ?]]", SerializableTests.serializable(Monad[Free[Option, ?]]))
checkAll("Free[Option, ?]", MonadRecTests[Free[Option, ?]].monadRec[Int, Int, Int])
checkAll("MonadRec[Free[Option, ?]]", SerializableTests.serializable(MonadRec[Free[Option, ?]]))

test("mapSuspension id"){
forAll { x: Free[List, Int] =>
Expand Down Expand Up @@ -43,6 +44,13 @@ class FreeTests extends CatsSuite {
}
}

test("tailRecM is stack safe") {
val n = 50000
val fa = MonadRec[Free[Option, ?]].tailRecM(0)(i =>
Free.pure[Option, Int Xor Int](if(i < n) Xor.Left(i+1) else Xor.Right(i)))
fa should === (Free.pure[Option, Int](n))
}

ignore("foldMap is stack safe") {
trait FTestApi[A]
case class TB(i: Int) extends FTestApi[Int]
Expand Down
26 changes: 26 additions & 0 deletions laws/src/main/scala/cats/laws/FlatMapRecLaws.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package cats
package laws

import cats.data.Xor
import cats.syntax.flatMap._
import cats.syntax.functor._

/**
* Laws that must be obeyed by any `FlatMapRec`.
*/
trait FlatMapRecLaws[F[_]] extends FlatMapLaws[F] {
implicit override def F: FlatMapRec[F]

def tailRecMConsistentFlatMap[A](a: A, f: A => F[A]): IsEq[F[A]] = {
val bounce = F.tailRecM[(A, Int), A]((a, 1)) { case (a0, i) =>
if(i > 0) f(a0).map(a1 => Xor.left((a1, i-1)))
else f(a0).map(Xor.right)
}
bounce <-> f(a).flatMap(f)
}
}

object FlatMapRecLaws {
def apply[F[_]](implicit ev: FlatMapRec[F]): FlatMapRecLaws[F] =
new FlatMapRecLaws[F] { def F: FlatMapRec[F] = ev }
}
Loading

0 comments on commit d3c64d1

Please sign in to comment.