diff --git a/core/src/main/scala/cats/Traverse.scala b/core/src/main/scala/cats/Traverse.scala index 24015db6b9..4b6f86f534 100644 --- a/core/src/main/scala/cats/Traverse.scala +++ b/core/src/main/scala/cats/Traverse.scala @@ -1,5 +1,8 @@ package cats +import cats.data.State +import cats.data.StateT + import simulacrum.typeclass /** @@ -97,4 +100,34 @@ import simulacrum.typeclass override def map[A, B](fa: F[A])(f: A => B): F[B] = traverse[Id, A, B](fa)(f) + + /** + * Akin to [[map]], but also provides the value's index in structure + * F when calling the function. + */ + def mapWithIndex[A, B](fa: F[A])(f: (A, Int) => B): F[B] = + traverse(fa)(a => + State((s: Int) => (s + 1, f(a, s)))).runA(0).value + + /** + * Akin to [[traverse]], but also provides the value's index in + * structure F when calling the function. + * + * This performs the traversal in a single pass but requires that + * effect G is monadic. An applicative traveral can be performed in + * two passes using [[zipWithIndex]] followed by [[traverse]]. + */ + def traverseWithIndexM[G[_], A, B](fa: F[A])(f: (A, Int) => G[B])(implicit G: Monad[G]): G[F[B]] = + traverse(fa)(a => + StateT((s: Int) => G.map(f(a, s))(b => (s + 1, b)))).runA(0) + + /** + * Traverses through the structure F, pairing the values with + * assigned indices. + * + * The behavior is consistent with the Scala collection library's + * `zipWithIndex` for collections such as `List`. + */ + def zipWithIndex[A](fa: F[A]): F[(A, Int)] = + mapWithIndex(fa)((a, i) => (a, i)) } diff --git a/core/src/main/scala/cats/instances/list.scala b/core/src/main/scala/cats/instances/list.scala index 9cb87d619c..eb203aed25 100644 --- a/core/src/main/scala/cats/instances/list.scala +++ b/core/src/main/scala/cats/instances/list.scala @@ -70,6 +70,12 @@ trait ListInstances extends cats.kernel.instances.ListInstances { G.map2Eval(f(a), lglb)(_ :: _) }.value + override def mapWithIndex[A, B](fa: List[A])(f: (A, Int) => B): List[B] = + fa.iterator.zipWithIndex.map(ai => f(ai._1, ai._2)).toList + + override def zipWithIndex[A](fa: List[A]): List[(A, Int)] = + fa.zipWithIndex + @tailrec override def get[A](fa: List[A])(idx: Long): Option[A] = fa match { diff --git a/core/src/main/scala/cats/instances/stream.scala b/core/src/main/scala/cats/instances/stream.scala index 4af61f8144..8ade719625 100644 --- a/core/src/main/scala/cats/instances/stream.scala +++ b/core/src/main/scala/cats/instances/stream.scala @@ -55,6 +55,12 @@ trait StreamInstances extends cats.kernel.instances.StreamInstances { }.value } + override def mapWithIndex[A, B](fa: Stream[A])(f: (A, Int) => B): Stream[B] = + fa.zipWithIndex.map(ai => f(ai._1, ai._2)) + + override def zipWithIndex[A](fa: Stream[A]): Stream[(A, Int)] = + fa.zipWithIndex + def tailRecM[A, B](a: A)(fn: A => Stream[Either[A, B]]): Stream[B] = { val it: Iterator[B] = new Iterator[B] { var stack: Stream[Either[A, B]] = fn(a) diff --git a/core/src/main/scala/cats/instances/vector.scala b/core/src/main/scala/cats/instances/vector.scala index caeed44fb8..5c466aa76c 100644 --- a/core/src/main/scala/cats/instances/vector.scala +++ b/core/src/main/scala/cats/instances/vector.scala @@ -82,6 +82,12 @@ trait VectorInstances extends cats.kernel.instances.VectorInstances { G.map2Eval(f(a), lgvb)(_ +: _) }.value + override def mapWithIndex[A, B](fa: Vector[A])(f: (A, Int) => B): Vector[B] = + fa.iterator.zipWithIndex.map(ai => f(ai._1, ai._2)).toVector + + override def zipWithIndex[A](fa: Vector[A]): Vector[(A, Int)] = + fa.zipWithIndex + override def exists[A](fa: Vector[A])(p: A => Boolean): Boolean = fa.exists(p) diff --git a/tests/src/test/scala/cats/tests/TraverseTests.scala b/tests/src/test/scala/cats/tests/TraverseTests.scala new file mode 100644 index 0000000000..de48c339a1 --- /dev/null +++ b/tests/src/test/scala/cats/tests/TraverseTests.scala @@ -0,0 +1,76 @@ +package cats +package tests + +import org.scalatest.prop.PropertyChecks +import org.scalacheck.Arbitrary + +import cats.instances.all._ + +abstract class TraverseCheck[F[_]: Traverse](name: String)(implicit ArbFInt: Arbitrary[F[Int]]) extends CatsSuite with PropertyChecks { + + test(s"Traverse[$name].zipWithIndex") { + forAll { (fa: F[Int]) => + fa.zipWithIndex.toList should === (fa.toList.zipWithIndex) + } + } + + test(s"Traverse[$name].mapWithIndex") { + forAll { (fa: F[Int], fn: ((Int, Int)) => Int) => + fa.mapWithIndex((a, i) => fn((a, i))).toList should === (fa.toList.zipWithIndex.map(fn)) + } + } + + test(s"Traverse[$name].traverseWithIndexM") { + forAll { (fa: F[Int], fn: ((Int, Int)) => (Int, Int)) => + val left = fa.traverseWithIndexM((a, i) => fn((a, i))).map(_.toList) + val (xs, values) = fa.toList.zipWithIndex.map(fn).unzip + left should === ((xs.combineAll, values)) + } + } + +} + +object TraverseCheck { + // forces testing of the underlying implementation (avoids overridden methods) + abstract class Underlying[F[_]: Traverse](name: String)(implicit ArbFInt: Arbitrary[F[Int]]) + extends TraverseCheck(s"$name (underlying)")(proxyTraverse[F], ArbFInt) + + // proxies a traverse instance so we can test default implementations + // to achieve coverage using default datatype instances + private def proxyTraverse[F[_]: Traverse]: Traverse[F] = new Traverse[F] { + def foldLeft[A, B](fa: F[A], b: B)(f: (B, A) => B): B = + Traverse[F].foldLeft(fa, b)(f) + def foldRight[A, B](fa: F[A], lb: cats.Eval[B])(f: (A, Eval[B]) => Eval[B]): Eval[B] = + Traverse[F].foldRight(fa, lb)(f) + def traverse[G[_]: Applicative, A, B](fa: F[A])(f: A => G[B]): G[F[B]] = + Traverse[F].traverse(fa)(f) + } +} + +class TraverseListCheck extends TraverseCheck[List]("List") +class TraverseStreamCheck extends TraverseCheck[Stream]("Stream") +class TraverseVectorCheck extends TraverseCheck[Vector]("Vector") + +class TraverseListCheckUnderlying extends TraverseCheck.Underlying[List]("List") +class TraverseStreamCheckUnderlying extends TraverseCheck.Underlying[Stream]("Stream") +class TraverseVectorCheckUnderlying extends TraverseCheck.Underlying[Vector]("Vector") + +class TraverseTestsAdditional extends CatsSuite { + + def checkZipWithIndexedStackSafety[F[_]](fromRange: Range => F[Int])(implicit F: Traverse[F]): Unit = { + F.zipWithIndex(fromRange(1 to 70000)) + () + } + + test("Traverse[List].zipWithIndex stack safety") { + checkZipWithIndexedStackSafety[List](_.toList) + } + + test("Traverse[Stream].zipWithIndex stack safety") { + checkZipWithIndexedStackSafety[Stream](_.toStream) + } + + test("Traverse[Vector].zipWithIndex stack safety") { + checkZipWithIndexedStackSafety[Vector](_.toVector) + } +}