diff --git a/core/src/main/scala/cats/ApplicativeError.scala b/core/src/main/scala/cats/ApplicativeError.scala index a64affd5cb6..ae71da6c3b5 100644 --- a/core/src/main/scala/cats/ApplicativeError.scala +++ b/core/src/main/scala/cats/ApplicativeError.scala @@ -2,6 +2,8 @@ package cats import cats.data.{EitherT, Validated} import cats.data.Validated.{Invalid, Valid} +import cats.ApplicativeError.CatchOnlyPartiallyApplied +import cats.data.EitherT import scala.reflect.ClassTag import scala.util.{Failure, Success, Try} @@ -157,8 +159,6 @@ trait ApplicativeError[F[_], E] extends Applicative[F] { * @param fa is the source whose result is going to get transformed * @param recover is the function that gets called to recover the source * in case of error - * @param map is the function that gets to transform the source - * in case of success */ def redeem[A, B](fa: F[A])(recover: E => B, f: A => B): F[B] = handleError(map(fa)(f))(recover) @@ -216,6 +216,12 @@ trait ApplicativeError[F[_], E] extends Applicative[F] { case NonFatal(e) => raiseError(e) } + /** + * Evaluates the specified block, catching exceptions of the specified type. Uncaught exceptions are propagated. + */ + def catchOnly[T >: Null <: Throwable]: CatchOnlyPartiallyApplied[T, F, E] = + new CatchOnlyPartiallyApplied[T, F, E](this) + /** * If the error type is Throwable, we can convert from a scala.util.Try */ @@ -301,6 +307,17 @@ object ApplicativeError { } } + final private[cats] class CatchOnlyPartiallyApplied[T, F[_], E](private val F: ApplicativeError[F, E]) + extends AnyVal { + def apply[A](f: => A)(implicit CT: ClassTag[T], NT: NotNull[T], ev: Throwable <:< E): F[A] = + try { + F.pure(f) + } catch { + case t if CT.runtimeClass.isInstance(t) => + F.raiseError(t) + } + } + /** * lift from scala.Option[A] to a F[A] * diff --git a/core/src/main/scala/cats/instances/try.scala b/core/src/main/scala/cats/instances/try.scala index b1b571b0066..3e98c024fe1 100644 --- a/core/src/main/scala/cats/instances/try.scala +++ b/core/src/main/scala/cats/instances/try.scala @@ -137,6 +137,10 @@ trait TryInstances extends TryInstances1 { } override def isEmpty[A](fa: Try[A]): Boolean = fa.isFailure + + override def catchNonFatal[A](a: => A)(implicit ev: Throwable <:< Throwable): Try[A] = Try(a) + + override def catchNonFatalEval[A](a: Eval[A])(implicit ev: Throwable <:< Throwable): Try[A] = Try(a.value) } // scalastyle:on method.length diff --git a/tests/src/test/scala/cats/tests/TrySuite.scala b/tests/src/test/scala/cats/tests/TrySuite.scala index f59ad050a70..ef3d698c55e 100644 --- a/tests/src/test/scala/cats/tests/TrySuite.scala +++ b/tests/src/test/scala/cats/tests/TrySuite.scala @@ -61,6 +61,17 @@ class TrySuite extends CatsSuite { res should not be (null) } } + + test("catchOnly works") { + forAll { e: Either[String, Int] => + val str = e.fold(identity, _.toString) + val res = MonadError[Try, Throwable].catchOnly[NumberFormatException](str.toInt) + // the above should just never cause an uncaught exception + // this is a somewhat bogus test: + res should not be (null) + } + } + test("fromTry works") { forAll { t: Try[Int] => (MonadError[Try, Throwable].fromTry(t)) should ===(t)