From f1de85a63f52602839ac96a95ffc3c32f6967ebd Mon Sep 17 00:00:00 2001 From: Fabio Labella Date: Wed, 6 Dec 2017 16:22:01 +0000 Subject: [PATCH] Add `MonadError.rethrow` (#2061) * Add `MonadError.rethrow` * change handleErrorWith to take Both as valid case * fix build * fix doctest --- build.sbt | 5 +++-- core/src/main/scala/cats/MonadError.scala | 20 +++++++++++++++++++ core/src/main/scala/cats/data/Ior.scala | 3 +-- core/src/main/scala/cats/data/IorT.scala | 3 +-- .../main/scala/cats/syntax/monadError.scala | 7 +++++++ .../main/scala/cats/laws/MonadErrorLaws.scala | 3 +++ .../laws/discipline/MonadErrorTests.scala | 3 ++- 7 files changed, 37 insertions(+), 7 deletions(-) diff --git a/build.sbt b/build.sbt index 49686f5347..8fe41bc9c0 100644 --- a/build.sbt +++ b/build.sbt @@ -313,8 +313,9 @@ def mimaSettings(moduleName: String) = Seq( exclude[DirectMissingMethodProblem]("cats.data.ValidatedApplicative.sequence"), exclude[ReversedMissingMethodProblem]("cats.data.IorFunctions.fromEither"), exclude[DirectMissingMethodProblem]("cats.data.RWSTAlternative.traverse"), - exclude[DirectMissingMethodProblem]("cats.data.RWSTAlternative.sequence") - + exclude[DirectMissingMethodProblem]("cats.data.RWSTAlternative.sequence"), + exclude[ReversedMissingMethodProblem]("cats.MonadError.rethrow"), + exclude[ReversedMissingMethodProblem]("cats.syntax.MonadErrorSyntax.catsSyntaxMonadErrorRethrow") ) } ) diff --git a/core/src/main/scala/cats/MonadError.scala b/core/src/main/scala/cats/MonadError.scala index 12348002c9..f861f82803 100644 --- a/core/src/main/scala/cats/MonadError.scala +++ b/core/src/main/scala/cats/MonadError.scala @@ -41,6 +41,26 @@ trait MonadError[F[_], E] extends ApplicativeError[F, E] with Monad[F] { */ def adaptError[A](fa: F[A])(pf: PartialFunction[E, E]): F[A] = flatMap(attempt(fa))(_.fold(e => raiseError(pf.applyOrElse[E, E](e, _ => e)), pure)) + + /** + * Inverse of `attempt` + * + * Example: + * {{{ + * scala> import cats.implicits._ + * scala> import scala.util.{Try, Success} + * + * scala> val a: Try[Either[Throwable, Int]] = Success(Left(new java.lang.Exception)) + * scala> a.rethrow + * res0: scala.util.Try[Int] = Failure(java.lang.Exception) + * + * scala> val b: Try[Either[Throwable, Int]] = Success(Right(1)) + * scala> b.rethrow + * res1: scala.util.Try[Int] = Success(1) + * }}} + */ + def rethrow[A](fa: F[Either[E, A]]): F[A] = + flatMap(fa)(_.fold(raiseError, pure)) } object MonadError { diff --git a/core/src/main/scala/cats/data/Ior.scala b/core/src/main/scala/cats/data/Ior.scala index 12109c4627..04e3b64621 100644 --- a/core/src/main/scala/cats/data/Ior.scala +++ b/core/src/main/scala/cats/data/Ior.scala @@ -165,8 +165,7 @@ private[data] sealed abstract class IorInstances extends IorInstances0 { def handleErrorWith[B](fa: Ior[A, B])(f: (A) => Ior[A, B]): Ior[A, B] = fa match { case Ior.Left(e) => f(e) - case r @ Ior.Right(_) => r - case Ior.Both(e, _) => f(e) + case _ => fa } def flatMap[B, C](fa: Ior[A, B])(f: B => Ior[A, C]): Ior[A, C] = fa.flatMap(f) diff --git a/core/src/main/scala/cats/data/IorT.scala b/core/src/main/scala/cats/data/IorT.scala index 01285647ba..368568990a 100644 --- a/core/src/main/scala/cats/data/IorT.scala +++ b/core/src/main/scala/cats/data/IorT.scala @@ -482,8 +482,7 @@ private[data] sealed trait IorTMonadError[F[_], A] extends MonadError[IorT[F, A, override def handleErrorWith[B](iort: IorT[F, A, B])(f: A => IorT[F, A, B]): IorT[F, A, B] = IorT(F0.flatMap(iort.value) { case Ior.Left(a) => f(a).value - case r @ Ior.Right(_) => F0.pure(r) - case Ior.Both(a, _) => f(a).value // should a be combined with result ?? + case r @ (Ior.Right(_) | Ior.Both(_, _)) => F0.pure(r) }) } diff --git a/core/src/main/scala/cats/syntax/monadError.scala b/core/src/main/scala/cats/syntax/monadError.scala index 95a4f07d53..75b4dbe001 100644 --- a/core/src/main/scala/cats/syntax/monadError.scala +++ b/core/src/main/scala/cats/syntax/monadError.scala @@ -4,6 +4,9 @@ package syntax trait MonadErrorSyntax { implicit final def catsSyntaxMonadError[F[_], E, A](fa: F[A])(implicit F: MonadError[F, E]): MonadErrorOps[F, E, A] = new MonadErrorOps(fa) + + implicit final def catsSyntaxMonadErrorRethrow[F[_], E, A](fea: F[Either[E, A]])(implicit F: MonadError[F, E]): MonadErrorRethrowOps[F, E, A] = + new MonadErrorRethrowOps(fea) } final class MonadErrorOps[F[_], E, A](val fa: F[A]) extends AnyVal { @@ -16,3 +19,7 @@ final class MonadErrorOps[F[_], E, A](val fa: F[A]) extends AnyVal { def adaptError(pf: PartialFunction[E, E])(implicit F: MonadError[F, E]): F[A] = F.adaptError(fa)(pf) } + +final class MonadErrorRethrowOps[F[_], E, A](val fea: F[Either[E, A]]) extends AnyVal { + def rethrow(implicit F: MonadError[F, E]): F[A] = F.rethrow(fea) +} diff --git a/laws/src/main/scala/cats/laws/MonadErrorLaws.scala b/laws/src/main/scala/cats/laws/MonadErrorLaws.scala index 763003a2bb..8a6abbd116 100644 --- a/laws/src/main/scala/cats/laws/MonadErrorLaws.scala +++ b/laws/src/main/scala/cats/laws/MonadErrorLaws.scala @@ -19,6 +19,9 @@ trait MonadErrorLaws[F[_], E] extends ApplicativeErrorLaws[F, E] with MonadLaws[ def adaptErrorRaise[A](e: E, f: E => E): IsEq[F[A]] = F.adaptError(F.raiseError[A](e))(PartialFunction(f)) <-> F.raiseError(f(e)) + + def rethrowAttempt[A](fa: F[A]): IsEq[F[A]] = + F.rethrow(F.attempt(fa)) <-> fa } object MonadErrorLaws { diff --git a/laws/src/main/scala/cats/laws/discipline/MonadErrorTests.scala b/laws/src/main/scala/cats/laws/discipline/MonadErrorTests.scala index f63cdfa111..53e23ca65d 100644 --- a/laws/src/main/scala/cats/laws/discipline/MonadErrorTests.scala +++ b/laws/src/main/scala/cats/laws/discipline/MonadErrorTests.scala @@ -42,7 +42,8 @@ trait MonadErrorTests[F[_], E] extends ApplicativeErrorTests[F, E] with MonadTes "monadError ensure consistency" -> forAll(laws.monadErrorEnsureConsistency[A] _), "monadError ensureOr consistency" -> forAll(laws.monadErrorEnsureOrConsistency[A] _), "monadError adaptError pure" -> forAll(laws.adaptErrorPure[A] _), - "monadError adaptError raise" -> forAll(laws.adaptErrorRaise[A] _) + "monadError adaptError raise" -> forAll(laws.adaptErrorRaise[A] _), + "monadError rethrow attempt" -> forAll(laws.rethrowAttempt[A] _) ) } }