diff --git a/core/src/main/scala-2.13+/cats/instances/arraySeq.scala b/core/src/main/scala-2.13+/cats/instances/arraySeq.scala index 355664b0be..27e7db072f 100644 --- a/core/src/main/scala-2.13+/cats/instances/arraySeq.scala +++ b/core/src/main/scala-2.13+/cats/instances/arraySeq.scala @@ -222,5 +222,10 @@ private[cats] object ArraySeqInstances { fa.foldRight(Eval.now(G.pure(ArraySeq.untagged.empty[A]))) { case (x, xse) => G.map2Eval(f(x), xse)((b, vec) => if (b) x +: vec else vec) }.value + + override def mapAccumulateFilter[S, A, B](init: S, fa: ArraySeq[A])( + f: (S, A) => (S, Option[B]) + ): (S, ArraySeq[B]) = + StaticMethods.mapAccumulateFilterFromStrictFunctorFilter(init, fa, f)(this) } } diff --git a/core/src/main/scala/cats/TraverseFilter.scala b/core/src/main/scala/cats/TraverseFilter.scala index 2f44b0f7c3..f68b65ac2b 100644 --- a/core/src/main/scala/cats/TraverseFilter.scala +++ b/core/src/main/scala/cats/TraverseFilter.scala @@ -122,6 +122,21 @@ trait TraverseFilter[F[_]] extends FunctorFilter[F] { override def mapFilter[A, B](fa: F[A])(f: A => Option[B]): F[B] = traverseFilter[Id, A, B](fa)(f) + /** + * Like [[mapAccumulate]], but allows `Option` in supplied accumulating function, + * keeping only `Some`s. + * + * Example: + * {{{ + * scala> import cats.syntax.all._ + * scala> val sumAllAndKeepOdd = (s: Int, n: Int) => (s + n, Option.when(n % 2 == 1)(n)) + * scala> List(1, 2, 3, 4).mapAccumulateFilter(0, sumAllAndKeepOdd) + * res1: (Int, List[Int]) = (10, List(1, 3)) + * }}} + */ + def mapAccumulateFilter[S, A, B](init: S, fa: F[A])(f: (S, A) => (S, Option[B])): (S, F[B]) = + traverseFilter(fa)(a => State(s => f(s, a))).run(init).value + /** * Removes duplicate elements from a list, keeping only the first occurrence. */ @@ -184,6 +199,8 @@ object TraverseFilter { typeClassInstance.filterA[G, A](self)(f)(G) def traverseEither[G[_], B, C](f: A => G[Either[C, B]])(g: (A, C) => G[Unit])(implicit G: Monad[G]): G[F[B]] = typeClassInstance.traverseEither[G, A, B, C](self)(f)(g)(G) + def mapAccumulateFilter[S, B](init: S)(f: (S, A) => (S, Option[B])): (S, F[B]) = + typeClassInstance.mapAccumulateFilter[S, A, B](init, self)(f) def ordDistinct(implicit O: Order[A]): F[A] = typeClassInstance.ordDistinct(self) def hashDistinct(implicit H: Hash[A]): F[A] = typeClassInstance.hashDistinct(self) } diff --git a/core/src/main/scala/cats/data/Chain.scala b/core/src/main/scala/cats/data/Chain.scala index 69c0d2da5e..e444a1982a 100644 --- a/core/src/main/scala/cats/data/Chain.scala +++ b/core/src/main/scala/cats/data/Chain.scala @@ -1369,6 +1369,9 @@ sealed abstract private[data] class ChainInstances extends ChainInstances1 { ) .value + override def mapAccumulateFilter[S, A, B](init: S, fa: Chain[A])(f: (S, A) => (S, Option[B])): (S, Chain[B]) = + StaticMethods.mapAccumulateFilterFromStrictFunctorFilter(init, fa, f)(this) + } private[this] val theMonoid: Monoid[Chain[Any]] = new Monoid[Chain[Any]] { diff --git a/core/src/main/scala/cats/instances/StaticMethods.scala b/core/src/main/scala/cats/instances/StaticMethods.scala index 8df598bddd..fa44ab968b 100644 --- a/core/src/main/scala/cats/instances/StaticMethods.scala +++ b/core/src/main/scala/cats/instances/StaticMethods.scala @@ -21,7 +21,7 @@ package cats.instances -import cats.Functor +import cats.{Functor, FunctorFilter} import scala.collection.mutable.Builder @@ -70,4 +70,18 @@ private[cats] object StaticMethods { } } + def mapAccumulateFilterFromStrictFunctorFilter[S, F[_], A, B](init: S, fa: F[A], f: (S, A) => (S, Option[B]))(implicit + ev: FunctorFilter[F] + ): (S, F[B]) = { + var state = init + + val fb = ev.mapFilter(fa) { a => + val (newState, b) = f(state, a) + state = newState + b + } + + (state, fb) + } + } diff --git a/core/src/main/scala/cats/instances/list.scala b/core/src/main/scala/cats/instances/list.scala index f4d01cda86..72d0a9201b 100644 --- a/core/src/main/scala/cats/instances/list.scala +++ b/core/src/main/scala/cats/instances/list.scala @@ -322,5 +322,8 @@ private[instances] trait ListInstancesBinCompat0 { G.map2Eval(f(x), xse)((b, list) => if (b) x :: list else list) ) .value + + override def mapAccumulateFilter[S, A, B](init: S, fa: List[A])(f: (S, A) => (S, Option[B])): (S, List[B]) = + StaticMethods.mapAccumulateFilterFromStrictFunctorFilter(init, fa, f)(this) } } diff --git a/core/src/main/scala/cats/instances/option.scala b/core/src/main/scala/cats/instances/option.scala index 315b91f942..f8777a8911 100644 --- a/core/src/main/scala/cats/instances/option.scala +++ b/core/src/main/scala/cats/instances/option.scala @@ -285,5 +285,11 @@ private[instances] trait OptionInstancesBinCompat0 { case Some(a) => G.map(f(a))(b => if (b) Some(a) else None) } + override def mapAccumulateFilter[S, A, B](init: S, fa: Option[A])(f: (S, A) => (S, Option[B])): (S, Option[B]) = + fa match { + case Some(a) => f(init, a) + case None => (init, None) + } + } } diff --git a/core/src/main/scala/cats/instances/queue.scala b/core/src/main/scala/cats/instances/queue.scala index 1bfbc52af6..00c06c0ddf 100644 --- a/core/src/main/scala/cats/instances/queue.scala +++ b/core/src/main/scala/cats/instances/queue.scala @@ -237,5 +237,8 @@ private object QueueInstances { G.map2Eval(f(x), xse)((b, queue) => if (b) x +: queue else queue) ) .value + + override def mapAccumulateFilter[S, A, B](init: S, fa: Queue[A])(f: (S, A) => (S, Option[B])): (S, Queue[B]) = + StaticMethods.mapAccumulateFilterFromStrictFunctorFilter(init, fa, f)(this) } } diff --git a/core/src/main/scala/cats/instances/vector.scala b/core/src/main/scala/cats/instances/vector.scala index 88db57d2b4..872ec1c42f 100644 --- a/core/src/main/scala/cats/instances/vector.scala +++ b/core/src/main/scala/cats/instances/vector.scala @@ -271,5 +271,8 @@ private[instances] trait VectorInstancesBinCompat0 { G.map2Eval(f(x), xse)((b, vector) => if (b) x +: vector else vector) ) .value + + override def mapAccumulateFilter[S, A, B](init: S, fa: Vector[A])(f: (S, A) => (S, Option[B])): (S, Vector[B]) = + StaticMethods.mapAccumulateFilterFromStrictFunctorFilter(init, fa, f)(this) } } diff --git a/tests/shared/src/test/scala/cats/tests/TraverseFilterSuite.scala b/tests/shared/src/test/scala/cats/tests/TraverseFilterSuite.scala index 92d8131d6d..38f8f38fc7 100644 --- a/tests/shared/src/test/scala/cats/tests/TraverseFilterSuite.scala +++ b/tests/shared/src/test/scala/cats/tests/TraverseFilterSuite.scala @@ -38,6 +38,19 @@ abstract class TraverseFilterSuite[F[_]: TraverseFilter](name: String)(implicit implicit def T: Traverse[F] = implicitly[TraverseFilter[F]].traverse + test(s"TraverseFilter[$name].mapAccumulateFilter") { + forAll { (init: Int, fa: F[Int], fn: ((Int, Int)) => (Int, Option[Int])) => + val lhs = fa.mapAccumulateFilter(init)((s, a) => fn((s, a))) + + val rhs = fa.foldLeft((init, List.empty[Int])) { case ((s1, acc), a) => + val (s2, b) = fn((s1, a)) + (s2, b.fold(acc)(_ :: acc)) + } + + assert(lhs.map(_.toList) === rhs.map(_.reverse)) + } + } + test(s"TraverseFilter[$name].ordDistinct") { forAll { (fa: F[Int]) => fa.ordDistinct.toList === fa.toList.distinct