Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add TraverseFilter.mapAccumulateFilter #4561

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
5 changes: 5 additions & 0 deletions core/src/main/scala-2.13+/cats/instances/arraySeq.scala
Original file line number Diff line number Diff line change
Expand Up @@ -222,5 +222,10 @@ private[cats] object ArraySeqInstances {
fa.foldRight(Eval.now(G.pure(ArraySeq.untagged.empty[A]))) { case (x, xse) =>
G.map2Eval(f(x), xse)((b, vec) => if (b) x +: vec else vec)
}.value

override def mapAccumulateFilter[S, A, B](init: S, fa: ArraySeq[A])(
f: (S, A) => (S, Option[B])
): (S, ArraySeq[B]) =
StaticMethods.mapAccumulateFilterFromStrictFunctorFilter(init, fa, f)(this)
}
}
17 changes: 17 additions & 0 deletions core/src/main/scala/cats/TraverseFilter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,21 @@ trait TraverseFilter[F[_]] extends FunctorFilter[F] {
override def mapFilter[A, B](fa: F[A])(f: A => Option[B]): F[B] =
traverseFilter[Id, A, B](fa)(f)

/**
* Like [[mapAccumulate]], but allows `Option` in supplied accumulating function,
* keeping only `Some`s.
*
* Example:
* {{{
* scala> import cats.syntax.all._
* scala> val sumAllAndKeepOdd = (s: Int, n: Int) => (s + n, Option.when(n % 2 == 1)(n))
* scala> List(1, 2, 3, 4).mapAccumulateFilter(0, sumAllAndKeepOdd)
* res1: (Int, List[Int]) = (10, List(1, 3))
* }}}
*/
def mapAccumulateFilter[S, A, B](init: S, fa: F[A])(f: (S, A) => (S, Option[B])): (S, F[B]) =
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel that a scaladoc comment (perhaps, with a short usage example) wouldn't hurt and could come handy here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Almost done, I should fix the code violations. Is the language correct?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me, thanks!

traverseFilter(fa)(a => State(s => f(s, a))).run(init).value
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we override this for some of the built in collections? State is rather slow so we should avoid it for List, Vector, Chain, NonEmptyList, NonEmptyVector, ...

Copy link
Contributor Author

@Masynchin Masynchin Feb 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can copy StaticMethods.mapAccumulateFromStrictFunctor for mapAccumulateFilter, this will cover for some collections


/**
* Removes duplicate elements from a list, keeping only the first occurrence.
*/
Expand Down Expand Up @@ -184,6 +199,8 @@ object TraverseFilter {
typeClassInstance.filterA[G, A](self)(f)(G)
def traverseEither[G[_], B, C](f: A => G[Either[C, B]])(g: (A, C) => G[Unit])(implicit G: Monad[G]): G[F[B]] =
typeClassInstance.traverseEither[G, A, B, C](self)(f)(g)(G)
def mapAccumulateFilter[S, B](init: S)(f: (S, A) => (S, Option[B])): (S, F[B]) =
typeClassInstance.mapAccumulateFilter[S, A, B](init, self)(f)
def ordDistinct(implicit O: Order[A]): F[A] = typeClassInstance.ordDistinct(self)
def hashDistinct(implicit H: Hash[A]): F[A] = typeClassInstance.hashDistinct(self)
}
Expand Down
3 changes: 3 additions & 0 deletions core/src/main/scala/cats/data/Chain.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1369,6 +1369,9 @@ sealed abstract private[data] class ChainInstances extends ChainInstances1 {
)
.value

override def mapAccumulateFilter[S, A, B](init: S, fa: Chain[A])(f: (S, A) => (S, Option[B])): (S, Chain[B]) =
StaticMethods.mapAccumulateFilterFromStrictFunctorFilter(init, fa, f)(this)

}

private[this] val theMonoid: Monoid[Chain[Any]] = new Monoid[Chain[Any]] {
Expand Down
16 changes: 15 additions & 1 deletion core/src/main/scala/cats/instances/StaticMethods.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

package cats.instances

import cats.Functor
import cats.{Functor, FunctorFilter}

import scala.collection.mutable.Builder

Expand Down Expand Up @@ -70,4 +70,18 @@ private[cats] object StaticMethods {
}
}

def mapAccumulateFilterFromStrictFunctorFilter[S, F[_], A, B](init: S, fa: F[A], f: (S, A) => (S, Option[B]))(implicit
ev: FunctorFilter[F]
): (S, F[B]) = {
var state = init

val fb = ev.mapFilter(fa) { a =>
val (newState, b) = f(state, a)
state = newState
b
}

(state, fb)
}

}
3 changes: 3 additions & 0 deletions core/src/main/scala/cats/instances/list.scala
Original file line number Diff line number Diff line change
Expand Up @@ -322,5 +322,8 @@ private[instances] trait ListInstancesBinCompat0 {
G.map2Eval(f(x), xse)((b, list) => if (b) x :: list else list)
)
.value

override def mapAccumulateFilter[S, A, B](init: S, fa: List[A])(f: (S, A) => (S, Option[B])): (S, List[B]) =
StaticMethods.mapAccumulateFilterFromStrictFunctorFilter(init, fa, f)(this)
}
}
6 changes: 6 additions & 0 deletions core/src/main/scala/cats/instances/option.scala
Original file line number Diff line number Diff line change
Expand Up @@ -285,5 +285,11 @@ private[instances] trait OptionInstancesBinCompat0 {
case Some(a) => G.map(f(a))(b => if (b) Some(a) else None)
}

override def mapAccumulateFilter[S, A, B](init: S, fa: Option[A])(f: (S, A) => (S, Option[B])): (S, Option[B]) =
fa match {
case Some(a) => f(init, a)
case None => (init, None)
}

}
}
3 changes: 3 additions & 0 deletions core/src/main/scala/cats/instances/queue.scala
Original file line number Diff line number Diff line change
Expand Up @@ -237,5 +237,8 @@ private object QueueInstances {
G.map2Eval(f(x), xse)((b, queue) => if (b) x +: queue else queue)
)
.value

override def mapAccumulateFilter[S, A, B](init: S, fa: Queue[A])(f: (S, A) => (S, Option[B])): (S, Queue[B]) =
StaticMethods.mapAccumulateFilterFromStrictFunctorFilter(init, fa, f)(this)
}
}
3 changes: 3 additions & 0 deletions core/src/main/scala/cats/instances/vector.scala
Original file line number Diff line number Diff line change
Expand Up @@ -271,5 +271,8 @@ private[instances] trait VectorInstancesBinCompat0 {
G.map2Eval(f(x), xse)((b, vector) => if (b) x +: vector else vector)
)
.value

override def mapAccumulateFilter[S, A, B](init: S, fa: Vector[A])(f: (S, A) => (S, Option[B])): (S, Vector[B]) =
StaticMethods.mapAccumulateFilterFromStrictFunctorFilter(init, fa, f)(this)
}
}
13 changes: 13 additions & 0 deletions tests/shared/src/test/scala/cats/tests/TraverseFilterSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,19 @@ abstract class TraverseFilterSuite[F[_]: TraverseFilter](name: String)(implicit

implicit def T: Traverse[F] = implicitly[TraverseFilter[F]].traverse

test(s"TraverseFilter[$name].mapAccumulateFilter") {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like it does not check the default implementation (based on State), does it?
Except maybe one for Stream, but I wouldn't count on it.

For testing default implementations we usually use ListWrapper from testkit:
ListWrapper.scala

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had straightforwardly copy-pasted and updated tests from mapAccumulate here:

test(s"Traverse[$name].mapAccumulate") {
forAll { (init: Int, fa: F[Int], fn: ((Int, Int)) => (Int, Int)) =>
val lhs = fa.mapAccumulate(init)((s, a) => fn((s, a)))
val rhs = fa.foldLeft((init, List.empty[Int])) { case ((s1, acc), a) =>
val (s2, b) = fn((s1, a))
(s2, b :: acc)
}
assert(lhs.map(_.toList) === rhs.map(_.reverse))
}
}

If that tests doesn't test default implementation too, I can update mapAccumulateFilter tests. Can you provide an example of how to pass ListWrapper?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For testing default implementations we usually use ListWrapper from testkit: ListWrapper.scala

I am failing to understand how to use ListWrapper to test mapAccumulateFilter. I was looking for examples in other Suites, but it is either suites for data and not typeclasses (OptionT, Try, etc.), or it refers to <Typeclass>Tests[ListWrapper].<methodToTest>, like in the ApplicativeSuite:

implicit val listwrapperApplicative: Applicative[ListWrapper] = ListWrapper.applicative
implicit val listwrapperCoflatMap: CoflatMap[ListWrapper] = Applicative.coflatMap[ListWrapper]
checkAll("Applicative[ListWrapper].coflatMap", CoflatMapTests[ListWrapper].coflatMap[String, String, String])

which I can not apply here, because we don't have TraverseFilter[?].mapAccumulateFilter. Am I missing something?

Copy link
Contributor

@satorg satorg Apr 10, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My apologies for the delay – I was snowed under a bit. Actually, there's TraverseFilter for ListWrapper:

val traverseFilter: TraverseFilter[ListWrapper] = {
val F = TraverseFilter[List]
new TraverseFilter[ListWrapper] {
def traverse = ListWrapper.traverse
def traverseFilter[G[_], A, B](
fa: ListWrapper[A]
)(f: A => G[Option[B]])(implicit G: Applicative[G]): G[ListWrapper[B]] =
G.map(F.traverseFilter(fa.list)(f))(ListWrapper.apply)
}
}

To test the default implementation you can either call it directly:

ListWrapper.traverseFilter.mapAccumulateFilter(...)

or make it an implicit in the scope:

implicit val listWrapperTraverseFilter: TraverseFilter[ListWrapper] = ListWrapper.traverseFilter

And then you can work with TraverseFilter for ListWrapper as usual.

forAll { (init: Int, fa: F[Int], fn: ((Int, Int)) => (Int, Option[Int])) =>
val lhs = fa.mapAccumulateFilter(init)((s, a) => fn((s, a)))

val rhs = fa.foldLeft((init, List.empty[Int])) { case ((s1, acc), a) =>
val (s2, b) = fn((s1, a))
(s2, b.fold(acc)(_ :: acc))
}

assert(lhs.map(_.toList) === rhs.map(_.reverse))
}
}

test(s"TraverseFilter[$name].ordDistinct") {
forAll { (fa: F[Int]) =>
fa.ordDistinct.toList === fa.toList.distinct
Expand Down
Loading