From 4e146de4e288d21c0c0cd8a4c6acf5abf99fb1f4 Mon Sep 17 00:00:00 2001 From: Daniil Filippov Date: Wed, 20 Dec 2023 10:35:53 +0300 Subject: [PATCH] Avoid double evaluation of pattern matchers in Chunk.collect --- .../scala/fs2/benchmark/ChunkBenchmark.scala | 10 ++++++ core/shared/src/main/scala/fs2/Chunk.scala | 3 +- .../src/test/scala/fs2/ChunkSuite.scala | 31 +++++++++++++++++++ 3 files changed, 43 insertions(+), 1 deletion(-) diff --git a/benchmark/src/main/scala/fs2/benchmark/ChunkBenchmark.scala b/benchmark/src/main/scala/fs2/benchmark/ChunkBenchmark.scala index b6f547811d..6c97b50557 100644 --- a/benchmark/src/main/scala/fs2/benchmark/ChunkBenchmark.scala +++ b/benchmark/src/main/scala/fs2/benchmark/ChunkBenchmark.scala @@ -46,4 +46,14 @@ class ChunkBenchmark { ints.filter(_ % 3 == 0) () } + + private object OddStringExtractor { + def unapply(i: Int): Option[String] = if (i % 2 != 0) Some(i.toString) else None + } + + @Benchmark + def collect(): Unit = { + ints.collect { case OddStringExtractor(s) => s } + () + } } diff --git a/core/shared/src/main/scala/fs2/Chunk.scala b/core/shared/src/main/scala/fs2/Chunk.scala index c29e7fd3fe..a49d79a4fe 100644 --- a/core/shared/src/main/scala/fs2/Chunk.scala +++ b/core/shared/src/main/scala/fs2/Chunk.scala @@ -87,7 +87,8 @@ abstract class Chunk[+O] extends Serializable with ChunkPlatform[O] with ChunkRu def collect[O2](pf: PartialFunction[O, O2]): Chunk[O2] = { val b = makeArrayBuilder[Any] b.sizeHint(size) - foreach(o => if (pf.isDefinedAt(o)) b += pf(o)) + val f = pf.runWith(b += _) + foreach { o => f(o); () } Chunk.array(b.result()).asInstanceOf[Chunk[O2]] } diff --git a/core/shared/src/test/scala/fs2/ChunkSuite.scala b/core/shared/src/test/scala/fs2/ChunkSuite.scala index 8685f6fe4b..f76f60cb38 100644 --- a/core/shared/src/test/scala/fs2/ChunkSuite.scala +++ b/core/shared/src/test/scala/fs2/ChunkSuite.scala @@ -31,6 +31,7 @@ import scodec.bits.ByteVector import java.nio.ByteBuffer import java.nio.CharBuffer +import java.util.concurrent.atomic.AtomicInteger import scala.reflect.ClassTag class ChunkSuite extends Fs2Suite { @@ -108,6 +109,36 @@ class ChunkSuite extends Fs2Suite { assert(Chunk.javaList(c.asJava) eq c) } } + + test("Chunk.collect behaves as filter + map") { + forAll { (c: Chunk[Int]) => + val extractor = new OddStringExtractor + val pf: PartialFunction[Int, String] = { case extractor(s) => s } + + val result = c.collect(pf) + + assertEquals(result, c.filter(pf.isDefinedAt).map(pf)) + } + } + + test("Chunk.collect evaluates pattern matchers once per item") { + forAll { (c: Chunk[Int]) => + val extractor = new OddStringExtractor + + val _ = c.collect { case extractor(s) => s } + + assertEquals(extractor.callCounter.get(), c.size) + } + } + + class OddStringExtractor { + val callCounter: AtomicInteger = new AtomicInteger(0) + + def unapply(i: Int): Option[String] = { + callCounter.incrementAndGet() + if (i % 2 != 0) Some(i.toString) else None + } + } } def testChunk[A: Arbitrary: ClassTag: CommutativeMonoid: Eq: Cogen](