Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove FlatMapRec make all FlatMap implement tailRecM #1280

Merged
merged 11 commits into from
Aug 12, 2016
4 changes: 2 additions & 2 deletions core/src/main/scala/cats/Eval.scala
Original file line number Diff line number Diff line change
Expand Up @@ -295,8 +295,8 @@ object Eval extends EvalInstances {

private[cats] trait EvalInstances extends EvalInstances0 {

implicit val catsBimonadForEval: Bimonad[Eval] with MonadRec[Eval] =
new Bimonad[Eval] with MonadRec[Eval] {
implicit val catsBimonadForEval: Bimonad[Eval] with Monad[Eval] with RecursiveTailRecM[Eval] =
new Bimonad[Eval] with Monad[Eval] with RecursiveTailRecM[Eval] {
override def map[A, B](fa: Eval[A])(f: A => B): Eval[B] = fa.map(f)
def pure[A](a: A): Eval[A] = Now(a)
def flatMap[A, B](fa: Eval[A])(f: A => Eval[B]): Eval[B] = fa.flatMap(f)
Expand Down
13 changes: 13 additions & 0 deletions core/src/main/scala/cats/FlatMap.scala
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package cats

import cats.data.Xor
import simulacrum.typeclass

/**
Expand Down Expand Up @@ -90,4 +91,16 @@ import simulacrum.typeclass
*/
def ifM[B](fa: F[Boolean])(ifTrue: => F[B], ifFalse: => F[B]): F[B] =
flatMap(fa)(if (_) ifTrue else ifFalse)

/**
* Keeps calling `f` until a `[[cats.data.Xor.Right Right]][B]` is returned.
*
* Based on Phil Freeman's
* [[http://functorial.com/stack-safety-for-free/index.pdf Stack Safety for Free]].
*
* Implementations of this method should ideally use constant stack space. If
* it is constant stack space, an instance of `RecursiveTailRecM[F]` should
* be made available.
*/
def tailRecM[A, B](a: A)(f: A => F[A Xor B]): F[B]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add a comment saying theres a defaultTailRecM that could be used as a stub/template?

}
28 changes: 0 additions & 28 deletions core/src/main/scala/cats/FlatMapRec.scala

This file was deleted.

13 changes: 13 additions & 0 deletions core/src/main/scala/cats/Monad.scala
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package cats

import cats.data.Xor
import simulacrum.typeclass

/**
Expand All @@ -14,4 +15,16 @@ import simulacrum.typeclass
@typeclass trait Monad[F[_]] extends FlatMap[F] with Applicative[F] {
override def map[A, B](fa: F[A])(f: A => B): F[B] =
flatMap(fa)(a => pure(f(a)))

/**
* This is not stack safe if the monad is not trampolined, but it
* is always lawful. It it better if you can find a stack safe way
* to write this method (all cats types have a stack safe version
* of this). When this method is safe you can find an `implicit r: RecursiveTailRecM`.
*/
protected def defaultTailRecM[A, B](a: A)(fn: A => F[A Xor B]): F[B] =
flatMap(fn(a)) {
case Xor.Right(b) => pure(b)
case Xor.Left(nextA) => defaultTailRecM(nextA)(fn)
}
}
5 changes: 0 additions & 5 deletions core/src/main/scala/cats/MonadRec.scala

This file was deleted.

24 changes: 24 additions & 0 deletions core/src/main/scala/cats/RecursiveTailRecM.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package cats

import java.io.Serializable
/**
* This is a marker type that promises that the method
* .tailRecM for this type is stack-safe for arbitrary recursion.
*/
trait RecursiveTailRecM[F[_]] extends Serializable {
/*
* you can call RecusiveTailRecM[F].sameType(Monad[F]).tailRec
* to have a static check that the types agree
* for safer usage of tailRecM
*/
final def sameType[M[_[_]]](m: M[F]): M[F] = m
}

object RecursiveTailRecM {
private[this] val singleton: RecursiveTailRecM[Id] = new RecursiveTailRecM[Id] { }

def apply[F[_]](implicit r: RecursiveTailRecM[F]): RecursiveTailRecM[F] = r

def create[F[_]]: RecursiveTailRecM[F] =
singleton.asInstanceOf[RecursiveTailRecM[F]]
}
13 changes: 12 additions & 1 deletion core/src/main/scala/cats/data/Cokleisli.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ package data
import cats.arrow.{Arrow, Split}
import cats.functor.{Contravariant, Profunctor}
import cats.{CoflatMap, Comonad, Functor, Monad}
import scala.annotation.tailrec

/**
* Represents a function `F[A] => B`.
Expand Down Expand Up @@ -47,7 +48,7 @@ private[data] sealed abstract class CokleisliInstances extends CokleisliInstance
implicit def catsDataArrowForCokleisli[F[_]](implicit ev: Comonad[F]): Arrow[Cokleisli[F, ?, ?]] =
new CokleisliArrow[F] { def F: Comonad[F] = ev }

implicit def catsDataMonadForCokleisli[F[_], A]: Monad[Cokleisli[F, A, ?]] = new Monad[Cokleisli[F, A, ?]] {
implicit def catsDataMonadForCokleisli[F[_], A]: Monad[Cokleisli[F, A, ?]] with RecursiveTailRecM[Cokleisli[F, A, ?]] = new Monad[Cokleisli[F, A, ?]] with RecursiveTailRecM[Cokleisli[F, A, ?]] {
def pure[B](x: B): Cokleisli[F, A, B] =
Cokleisli.pure(x)

Expand All @@ -56,6 +57,16 @@ private[data] sealed abstract class CokleisliInstances extends CokleisliInstance

override def map[B, C](fa: Cokleisli[F, A, B])(f: B => C): Cokleisli[F, A, C] =
fa.map(f)

def tailRecM[B, C](b: B)(fn: B => Cokleisli[F, A, B Xor C]): Cokleisli[F, A, C] =
Cokleisli({ (fa: F[A]) =>
@tailrec
def loop(c: Cokleisli[F, A, B Xor C]): C = c.run(fa) match {
case Xor.Right(c) => c
case Xor.Left(bb) => loop(fn(bb))
}
loop(fn(b))
})
}

implicit def catsDataMonoidKForCokleisli[F[_]](implicit ev: Comonad[F]): MonoidK[λ[α => Cokleisli[F, α, α]]] =
Expand Down
6 changes: 6 additions & 0 deletions core/src/main/scala/cats/data/IdT.scala
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ private[data] sealed trait IdTMonad[F[_]] extends Monad[IdT[F, ?]] {

def flatMap[A, B](fa: IdT[F, A])(f: A => IdT[F, B]): IdT[F, B] =
fa.flatMap(f)

def tailRecM[A, B](a: A)(f: A => IdT[F, A Xor B]): IdT[F, B] =
IdT(F0.tailRecM(a)(f(_).value))
}

private[data] sealed trait IdTFoldable[F[_]] extends Foldable[IdT[F, ?]] {
Expand Down Expand Up @@ -83,6 +86,9 @@ private[data] sealed abstract class IdTInstances0 extends IdTInstances1 {
implicit val F0: Monad[F] = F
}

implicit def catsDataRecursiveTailRecMForIdT[F[_]: RecursiveTailRecM]: RecursiveTailRecM[IdT[F, ?]] =
RecursiveTailRecM.create[IdT[F, ?]]

implicit def catsDataFoldableForIdT[F[_]](implicit F: Foldable[F]): Foldable[IdT[F, ?]] =
new IdTFoldable[F] {
implicit val F0: Foldable[F] = F
Expand Down
19 changes: 18 additions & 1 deletion core/src/main/scala/cats/data/Ior.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package cats
package data

import cats.functor.Bifunctor
import scala.annotation.tailrec

/** Represents a right-biased disjunction that is either an `A`, or a `B`, or both an `A` and a `B`.
*
Expand Down Expand Up @@ -140,9 +141,25 @@ private[data] sealed abstract class IorInstances extends IorInstances0 {
def show(f: A Ior B): String = f.show
}

implicit def catsDataMonadForIor[A: Semigroup]: Monad[A Ior ?] = new Monad[A Ior ?] {
implicit def catsDataMonadForIor[A: Semigroup]: Monad[A Ior ?] with RecursiveTailRecM[A Ior ?] = new Monad[A Ior ?] with RecursiveTailRecM[A Ior ?] {
def pure[B](b: B): A Ior B = Ior.right(b)
def flatMap[B, C](fa: A Ior B)(f: B => A Ior C): A Ior C = fa.flatMap(f)
def tailRecM[B, C](b: B)(fn: B => Ior[A, Xor[B, C]]): A Ior C = {
@tailrec
def loop(v: Ior[A, Xor[B, C]]): A Ior C = v match {
case Ior.Left(a) => Ior.left(a)
case Ior.Right(Xor.Right(c)) => Ior.right(c)
case Ior.Both(a, Xor.Right(c)) => Ior.both(a, c)
case Ior.Right(Xor.Left(b)) => loop(fn(b))
case Ior.Both(a, Xor.Left(b)) =>
fn(b) match {
case Ior.Left(aa) => Ior.left(Semigroup[A].combine(a, aa))
case Ior.Both(aa, x) => loop(Ior.both(Semigroup[A].combine(a, aa), x))
case Ior.Right(x) => loop(Ior.both(a, x))
}
}
loop(fn(b))
}
}

implicit def catsDataBifunctorForIor: Bifunctor[Ior] =
Expand Down
8 changes: 8 additions & 0 deletions core/src/main/scala/cats/data/Kleisli.scala
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,12 @@ private[data] sealed abstract class KleisliInstances0 extends KleisliInstances1

def map[B, C](fa: Kleisli[F, A, B])(f: B => C): Kleisli[F, A, C] =
fa.map(f)

def tailRecM[B, C](b: B)(f: B => Kleisli[F, A, B Xor C]): Kleisli[F, A, C] =
Kleisli[F, A, C]({ a => FlatMap[F].tailRecM(b) { f(_).run(a) } })
}
implicit def catsDataRecursiveTailRecMForKleisli[F[_]: RecursiveTailRecM, A]: RecursiveTailRecM[Kleisli[F, A, ?]] =
RecursiveTailRecM.create[Kleisli[F, A, ?]]

implicit def catsDataSemigroupForKleisli[F[_], A, B](implicit M: Semigroup[F[B]]): Semigroup[Kleisli[F, A, B]] =
new KleisliSemigroup[F, A, B] { def FB: Semigroup[F[B]] = M }
Expand Down Expand Up @@ -194,6 +199,9 @@ private[data] sealed abstract class KleisliInstances4 {

def local[B](f: A => A)(fa: Kleisli[F, A, B]): Kleisli[F, A, B] =
Kleisli(f.andThen(fa.run))

def tailRecM[B, C](b: B)(f: B => Kleisli[F, A, B Xor C]): Kleisli[F, A, C] =
Kleisli[F, A, C]({ a => FlatMap[F].tailRecM(b) { f(_).run(a) } })
}
}

Expand Down
6 changes: 3 additions & 3 deletions core/src/main/scala/cats/data/NonEmptyList.scala
Original file line number Diff line number Diff line change
Expand Up @@ -146,9 +146,9 @@ object NonEmptyList extends NonEmptyListInstances {
private[data] sealed trait NonEmptyListInstances extends NonEmptyListInstances0 {

implicit val catsDataInstancesForNonEmptyList: SemigroupK[NonEmptyList] with Reducible[NonEmptyList]
with Comonad[NonEmptyList] with Traverse[NonEmptyList] with MonadRec[NonEmptyList] =
new NonEmptyReducible[NonEmptyList, List] with SemigroupK[NonEmptyList]
with Comonad[NonEmptyList] with Traverse[NonEmptyList] with MonadRec[NonEmptyList] {
with Comonad[NonEmptyList] with Traverse[NonEmptyList] with Monad[NonEmptyList] with RecursiveTailRecM[NonEmptyList] =
new NonEmptyReducible[NonEmptyList, List] with SemigroupK[NonEmptyList] with Comonad[NonEmptyList]
with Traverse[NonEmptyList] with Monad[NonEmptyList] with RecursiveTailRecM[NonEmptyList] {

def combineK[A](a: NonEmptyList[A], b: NonEmptyList[A]): NonEmptyList[A] =
a concat b
Expand Down
6 changes: 3 additions & 3 deletions core/src/main/scala/cats/data/NonEmptyVector.scala
Original file line number Diff line number Diff line change
Expand Up @@ -135,9 +135,9 @@ final class NonEmptyVector[A] private (val toVector: Vector[A]) extends AnyVal {
private[data] sealed trait NonEmptyVectorInstances {

implicit val catsDataInstancesForNonEmptyVector: SemigroupK[NonEmptyVector] with Reducible[NonEmptyVector]
with Comonad[NonEmptyVector] with Traverse[NonEmptyVector] with MonadRec[NonEmptyVector] =
new NonEmptyReducible[NonEmptyVector, Vector] with SemigroupK[NonEmptyVector]
with Comonad[NonEmptyVector] with Traverse[NonEmptyVector] with MonadRec[NonEmptyVector] {
with Comonad[NonEmptyVector] with Traverse[NonEmptyVector] with Monad[NonEmptyVector] with RecursiveTailRecM[NonEmptyVector] =
new NonEmptyReducible[NonEmptyVector, Vector] with SemigroupK[NonEmptyVector] with Comonad[NonEmptyVector]
with Traverse[NonEmptyVector] with Monad[NonEmptyVector] with RecursiveTailRecM[NonEmptyVector] {

def combineK[A](a: NonEmptyVector[A], b: NonEmptyVector[A]): NonEmptyVector[A] =
a concatNev b
Expand Down
35 changes: 35 additions & 0 deletions core/src/main/scala/cats/data/OneAnd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,42 @@ private[data] sealed trait OneAndInstances extends OneAndLowPriority2 {
val fst = f(fa.head)
OneAnd(fst.head, monad.combineK(fst.tail, end))
}

def tailRecM[A, B](a: A)(fn: A => OneAnd[F, A Xor B]): OneAnd[F, B] = {
def stepF(a: A): F[A Xor B] = {
val oneAnd = fn(a)
monad.combineK(monad.pure(oneAnd.head), oneAnd.tail)
}
def toFB(in: A Xor B): F[B] = in match {
case Xor.Right(b) => monad.pure(b)
case Xor.Left(a) => monad.tailRecM(a)(stepF)
}

// This could probably be in SemigroupK to perform well
@tailrec
def combineAll(items: List[F[B]]): F[B] = items match {
case Nil => monad.empty
case h :: Nil => h
case h1 :: h2 :: tail => combineAll(monad.combineK(h1, h2) :: tail)
}

@tailrec
def go(in: A, rest: List[F[B]]): OneAnd[F, B] =
fn(in) match {
case OneAnd(Xor.Right(b), tail) =>
val fbs = monad.flatMap(tail)(toFB)
OneAnd(b, combineAll(fbs :: rest))
case OneAnd(Xor.Left(a), tail) =>
val fbs = monad.flatMap(tail)(toFB)
go(a, fbs :: rest)
}

go(a, Nil)
}
}

implicit def catsDataOneAnd[F[_]: RecursiveTailRecM]: RecursiveTailRecM[OneAnd[F, ?]] =
RecursiveTailRecM.create[OneAnd[F, ?]]
}

private[data] trait OneAndLowPriority0 {
Expand Down
13 changes: 4 additions & 9 deletions core/src/main/scala/cats/data/OptionT.scala
Original file line number Diff line number Diff line change
Expand Up @@ -175,8 +175,8 @@ object OptionT extends OptionTInstances {
}

private[data] sealed trait OptionTInstances extends OptionTInstances0 {
implicit def catsDataMonadRecForOptionT[F[_]](implicit F0: MonadRec[F]): MonadRec[OptionT[F, ?]] =
new OptionTMonadRec[F] { implicit val F = F0 }
implicit def catsDataMonadForOptionT[F[_]](implicit F0: Monad[F]): Monad[OptionT[F, ?]] =
new OptionTMonad[F] { implicit val F = F0 }

implicit def catsDataFoldableForOptionT[F[_]](implicit F0: Foldable[F]): Foldable[OptionT[F, ?]] =
new OptionTFoldable[F] { implicit val F = F0 }
Expand All @@ -195,6 +195,8 @@ private[data] sealed trait OptionTInstances0 extends OptionTInstances1 {
implicit def catsDataMonadErrorForOptionT[F[_], E](implicit F0: MonadError[F, E]): MonadError[OptionT[F, ?], E] =
new OptionTMonadError[F, E] { implicit val F = F0 }

implicit def catsDataRecursiveTailRecM[F[_]](implicit F: RecursiveTailRecM[F]): RecursiveTailRecM[OptionT[F, ?]] = RecursiveTailRecM.create[OptionT[F, ?]]

implicit def catsDataSemigroupKForOptionT[F[_]](implicit F0: Monad[F]): SemigroupK[OptionT[F, ?]] =
new OptionTSemigroupK[F] { implicit val F = F0 }

Expand All @@ -206,9 +208,6 @@ private[data] sealed trait OptionTInstances0 extends OptionTInstances1 {
}

private[data] sealed trait OptionTInstances1 extends OptionTInstances2 {
implicit def catsDataMonadForOptionT[F[_]](implicit F0: Monad[F]): Monad[OptionT[F, ?]] =
new OptionTMonad[F] { implicit val F = F0 }

// do NOT change this to val! I know it looks like it should work, and really I agree, but it doesn't (for... reasons)
implicit def catsDataTransLiftForOptionT: TransLift.Aux[OptionT, Functor] =
new TransLift[OptionT] {
Expand Down Expand Up @@ -251,10 +250,6 @@ private[data] trait OptionTMonad[F[_]] extends Monad[OptionT[F, ?]] {
def flatMap[A, B](fa: OptionT[F, A])(f: A => OptionT[F, B]): OptionT[F, B] = fa.flatMap(f)

override def map[A, B](fa: OptionT[F, A])(f: A => B): OptionT[F, B] = fa.map(f)
}

private[data] trait OptionTMonadRec[F[_]] extends MonadRec[OptionT[F, ?]] with OptionTMonad[F] {
implicit def F: MonadRec[F]

def tailRecM[A, B](a: A)(f: A => OptionT[F, A Xor B]): OptionT[F, B] =
OptionT(F.tailRecM(a)(a0 => F.map(f(a0).value)(
Expand Down
23 changes: 8 additions & 15 deletions core/src/main/scala/cats/data/StateT.scala
Original file line number Diff line number Diff line change
Expand Up @@ -159,19 +159,16 @@ private[data] sealed trait StateTInstances extends StateTInstances1 {
}

private[data] sealed trait StateTInstances1 extends StateTInstances2 {
implicit def catsDataMonadRecForStateT[F[_], S](implicit F0: MonadRec[F]): MonadRec[StateT[F, S, ?]] =
new StateTMonadRec[F, S] { implicit def F = F0 }
}

private[data] sealed trait StateTInstances2 extends StateTInstances3 {
implicit def catsDataMonadCombineForStateT[F[_], S](implicit F0: MonadCombine[F]): MonadCombine[StateT[F, S, ?]] =
new StateTMonadCombine[F, S] { implicit def F = F0 }
}

private[data] sealed trait StateTInstances3 {
private[data] sealed trait StateTInstances2 {
implicit def catsDataMonadForStateT[F[_], S](implicit F0: Monad[F]): Monad[StateT[F, S, ?]] =
new StateTMonad[F, S] { implicit def F = F0 }

implicit def catsDataRecursiveTailRecMForStateT[F[_]: RecursiveTailRecM, S]: RecursiveTailRecM[StateT[F, S, ?]] = RecursiveTailRecM.create[StateT[F, S, ?]]

implicit def catsDataSemigroupKForStateT[F[_], S](implicit F0: Monad[F], G0: SemigroupK[F]): SemigroupK[StateT[F, S, ?]] =
new StateTSemigroupK[F, S] { implicit def F = F0; implicit def G = G0 }
}
Expand Down Expand Up @@ -219,6 +216,11 @@ private[data] sealed trait StateTMonad[F[_], S] extends Monad[StateT[F, S, ?]] {
fa.flatMap(f)

override def map[A, B](fa: StateT[F, S, A])(f: A => B): StateT[F, S, B] = fa.map(f)

def tailRecM[A, B](a: A)(f: A => StateT[F, S, A Xor B]): StateT[F, S, B] =
StateT[F, S, B](s => F.tailRecM[(S, A), (S, B)]((s, a)) {
case (s, a) => F.map(f(a).run(s)) { case (s, ab) => ab.bimap((s, _), (s, _)) }
})
}

private[data] sealed trait StateTMonadState[F[_], S] extends MonadState[StateT[F, S, ?], S] with StateTMonad[F, S] {
Expand All @@ -227,15 +229,6 @@ private[data] sealed trait StateTMonadState[F[_], S] extends MonadState[StateT[F
def set(s: S): StateT[F, S, Unit] = StateT(_ => F.pure((s, ())))
}

private[data] sealed trait StateTMonadRec[F[_], S] extends MonadRec[StateT[F, S, ?]] with StateTMonad[F, S] {
override implicit def F: MonadRec[F]

def tailRecM[A, B](a: A)(f: A => StateT[F, S, A Xor B]): StateT[F, S, B] =
StateT[F, S, B](s => F.tailRecM[(S, A), (S, B)]((s, a)) {
case (s, a) => F.map(f(a).run(s)) { case (s, ab) => ab.bimap((s, _), (s, _)) }
})
}

private[data] sealed trait StateTTransLift[S] extends TransLift[StateT[?[_], S, ?]] {
type TC[M[_]] = Applicative[M]

Expand Down
Loading