Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ContT monad #2506

Merged
merged 5 commits into from
Nov 16, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 109 additions & 0 deletions core/src/main/scala/cats/data/ContT.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
package cats
package data

/**
* This is a continuation transformer based on the ContT in
* the Haskell package Control.Monad.Cont
*
* This is reasonably straight-forward except that to make
* a tailRecM implementation we leverage the Defer type class to
* obtain stack-safety.
*/
sealed abstract class ContT[M[_], A, +B] extends Serializable {
final def run: (B => M[A]) => M[A] = runAndThen
protected def runAndThen: AndThen[B => M[A], M[A]]

final def map[C](fn: B => C)(implicit M: Defer[M]): ContT[M, A, C] = {
// allocate/pattern match once
val fnAndThen = AndThen(fn)
ContT { fn2 =>
val cb = fnAndThen.andThen(fn2)
M.defer(run(cb))
}
}

/**
* c.mapCont(f).run(g) == f(c.run(g))
*/
final def mapCont(fn: M[A] => M[A]): ContT[M, A, B] =
// Use later here to avoid forcing run
ContT.later(runAndThen.andThen(fn))

/**
* cont.withCont(f).run(cb) == cont.run(f(cb))
*/
final def withCont[C](fn: (C => M[A]) => B => M[A]): ContT[M, A, C] =
// lazy to avoid forcing run
ContT.later(AndThen(fn).andThen(runAndThen))

final def flatMap[C](fn: B => ContT[M, A, C])(implicit M: Defer[M]): ContT[M, A, C] = {
// allocate/pattern match once
val fnAndThen = AndThen(fn)
ContT[M, A, C] { fn2 =>
val contRun: ContT[M, A, C] => M[A] = (_.run(fn2))
val fn3: B => M[A] = fnAndThen.andThen(contRun)
M.defer(run(fn3))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I stand very much corrected here. I'm convinced that this is stack-safe. I owe you an apology, @johnynek. Dear internets, please link here if you're looking for a handy example of me being wrong.

When I first reviewed this PR on my phone, it looked like a slightly different version of the trick you used in #1400, but this is considerably better. You're basically inheriting the stack-safety of the underlying monad by explicitly hitting a stack-safe join at this point and returning to its trampoline (which it must have due to Defer). I'm not sure you even need the AndThen, since the stack is cut by the defer.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We all make mistakes here and there. I very much admire folks who feel no shame about it. Kudos to you.

}
}
}

object ContT {

// Note, we only have two instances of ContT in order to be gentle on the JVM JIT
// which treats classes with more than two subclasses differently
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can just make ContT take runAndThen as a parameter to get the same benefit without the subtype limit. :-) A more impactful micro-optimization would be to make the implementations within ContT marked as final, since then it won't matter how many subtypes you have.


private case class FromFn[M[_], A, B](runAndThen: AndThen[B => M[A], M[A]]) extends ContT[M, A, B]

private case class DeferCont[M[_], A, B](next: () => ContT[M, A, B]) extends ContT[M, A, B] {
@annotation.tailrec
private def loop(n: () => ContT[M, A, B]): ContT[M, A, B] =
n() match {
case DeferCont(n) => loop(n)
case notDefer => notDefer
}

lazy val runAndThen: AndThen[B => M[A], M[A]] = loop(next).runAndThen
}

def pure[M[_], A, B](b: B): ContT[M, A, B] =
apply { cb => cb(b) }

def apply[M[_], A, B](fn: (B => M[A]) => M[A]): ContT[M, A, B] =
FromFn(AndThen(fn))
Copy link
Member

@djspiewak djspiewak Sep 18, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd also like to see an applyPure or something like that defined as:

def applyPure[M[_]: Applicative, A, B](fn: (B => A) => A): ContT[M, A, B] =
  apply[M, A, B](_(fn.andThen(_.pure)).pure)

Or thereabouts…

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

actually, I don't think we can do that I think we need Applicative and Comonad (we need to go B => M[A] to B => A to call fn.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh I see what you mean. That's definitely unfortunate. I'm not sure the function would be useful at all requiring Comonad


def later[M[_], A, B](fn: => (B => M[A]) => M[A]): ContT[M, A, B] =
DeferCont(() => FromFn(AndThen(fn)))

def tailRecM[M[_], A, B, C](a: A)(fn: A => ContT[M, C, Either[A, B]])(implicit M: Defer[M]): ContT[M, C, B] =
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can just delegate to flatMap now. Probably slightly slower than this though.

ContT[M, C, B] { cb: (B => M[C]) =>

def go(a: A): M[C] =
fn(a).run {
case Left(a) => M.defer(go(a))
case Right(b) => M.defer(cb(b))
}

go(a)
}

implicit def catsDataContTDefer[M[_], B]: Defer[ContT[M, B, ?]] =
new Defer[ContT[M, B, ?]] {
def defer[A](c: => ContT[M, B, A]): ContT[M, B, A] =
DeferCont(() => c)
}

implicit def catsDataContTMonad[M[_]: Defer, A]: Monad[ContT[M, A, ?]] =
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Random thought…

What if we implicitly prioritize a bit here? ContT is still useful with very short bind chains even when the underlying monad is not Defer, it just wouldn't be stack-safe. So make the higher priority Monad instance a Monad with Defer and require the Defer[M] constraint. Provide a lower priority instance which doesn't require or provide the Defer[M] (and thus also has a stack-unsafe tailRecM). Obviously this means removing the flatMap and map implementations from the ContT class itself and only making them visible via implicit enrichment (if the imports are a problem for usability, we can with MonadSyntax on the companion object).

The advantage would be a more broadly applicable ContT (admittedly, I'm not sure how much people care about continuations on the JVM outside of async effects, so… maybe not a big advantage?) which transparently preserves stack-safety whenever possible, and gracefully degrades whenever not. Maybe this is taking the "tailRecM doesn't really need to be stack-safe to be lawful" idea too far though.

new Monad[ContT[M, A, ?]] {
def pure[B](b: B): ContT[M, A, B] =
ContT.pure(b)

override def map[B, C](c: ContT[M, A, B])(fn: B => C): ContT[M, A, C] =
c.map(fn)

def flatMap[B, C](c: ContT[M, A, B])(fn: B => ContT[M, A, C]): ContT[M, A, C] =
c.flatMap(fn)

def tailRecM[B, C](b: B)(fn: B => ContT[M, A, Either[B, C]]): ContT[M, A, C] =
ContT.tailRecM(b)(fn)
}
}
69 changes: 69 additions & 0 deletions tests/src/test/scala/cats/tests/ContTSuite.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package cats
package tests

import cats.data.ContT
import cats.laws.discipline._
import cats.laws.discipline.arbitrary._
import org.scalacheck.{Arbitrary, Gen}

class ContTSuite extends CatsSuite {

implicit def arbContT[M[_], A, B](implicit arbFn: Arbitrary[(B => M[A]) => M[A]]): Arbitrary[ContT[M, A, B]] =
Arbitrary(arbFn.arbitrary.map(ContT[M, A, B](_)))

implicit def eqContT[M[_], A, B](implicit arbFn: Arbitrary[B => M[A]], eqMA: Eq[M[A]]): Eq[ContT[M, A, B]] = {
val genItems = Gen.listOfN(100, arbFn.arbitrary)
val fns = genItems.sample.get
new Eq[ContT[M, A, B]] {
def eqv(a: ContT[M, A, B], b: ContT[M, A, B]) =
fns.forall { fn =>
eqMA.eqv(a.run(fn), b.run(fn))
}
}
}

checkAll("ContT[Function0, Int, ?]", MonadTests[ContT[Function0, Int, ?]].monad[Int, String, Int])
checkAll("ContT[Eval, Int, ?]", MonadTests[ContT[Eval, Int, ?]].monad[Int, String, Int])
checkAll("ContT[Function0, Int, ?]", DeferTests[ContT[Function0, Int, ?]].defer[Int])

/**
* c.mapCont(f).run(g) == f(c.run(g))
*/
def mapContLaw[M[_], A, B](
implicit eqma: Eq[M[A]],
cArb: Arbitrary[ContT[M, A, B]],
fnArb: Arbitrary[M[A] => M[A]],
gnArb: Arbitrary[B => M[A]]) =
forAll { (cont: ContT[M, A, B], fn: M[A] => M[A], gn: B => M[A]) =>
assert(eqma.eqv(cont.mapCont(fn).run(gn), fn(cont.run(gn))))
}

/**
* cont.withCont(f).run(g) == cont.run(f(g))
*/
def withContLaw[M[_], A, B, C](
implicit eqma: Eq[M[A]],
cArb: Arbitrary[ContT[M, A, B]],
fnArb: Arbitrary[(C => M[A]) => B => M[A]],
gnArb: Arbitrary[C => M[A]]) =
forAll { (cont: ContT[M, A, B], fn: (C => M[A]) => B => M[A], gn: C => M[A]) =>
assert(eqma.eqv(cont.withCont(fn).run(gn), cont.run(fn(gn))))
}

test("ContT.mapContLaw[Function0, Int, String]") {
mapContLaw[Function0, Int, String]
}

test("ContT.mapContLaw[Eval, Int, String]") {
mapContLaw[Eval, Int, String]
}

test("ContT.withContLaw[Function0, Int, String, Int]") {
withContLaw[Function0, Int, String, Int]
}

test("ContT.withContLaw[Eval, Int, String, Int]") {
withContLaw[Eval, Int, String, Int]
}

}